@@ -382,6 +382,13 @@ function link(job, compiled)
382382 return compiled
383383end
384384
"""
    transpose_val(val)

Emit a `stablehlo.transpose` of the MLIR value `val` that reverses the order of its
dimensions, and return the result at index 1 of that operation.

The permutation is the zero-based dimension indices of `val`'s type listed from the
last dimension down to `0`.
"""
function transpose_val(val)
    rank = length(size(MLIR.IR.type(val)))
    # Zero-based dimension indices in reverse order: rank-1, ..., 1, 0.
    perm = collect(Int64, reverse(0:(rank - 1)))
    transposed = MLIR.Dialects.stablehlo.transpose(
        val; permutation=MLIR.IR.DenseArrayAttribute(perm)
    )
    return MLIR.IR.result(transposed, 1)
end
391+
385392function (func:: LLVMFunc{F,tt} )(args... ; convert= Val (false ), blocks:: CuDim = 1 , threads:: CuDim = 1 ,
386393 cooperative:: Bool = false , shmem:: Integer = 0 , call_kwargs... ) where {F, tt}
387394 @show args
@@ -392,32 +399,33 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th
392399
393400 mlir_args = MLIR. IR. Value[]
394401 restys = MLIR. IR. Type[]
395- aliases = MLIR. API . MlirAttribute []
402+ aliases = MLIR. IR . Attribute []
396403 rarrays = TracedRArray[]
397404 for (i, a) in enumerate (args)
398405 @show a
399406 @assert a isa CuTracedArray
400407 ta = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, a. ptr)):: TracedRArray
401408 push! (rarrays, ta)
402409 arg = ta. mlir_data
403- arg = Reactant. Compiler. transpose_val (arg)
404- push! (restys, MLIR. IR. Type (arg))
410+ arg = transpose_val (arg)
411+ @show arg
412+ push! (restys, MLIR. IR. type (arg))
405413 push! (aliases,
406- MLIR. IR. Dialects . stablehlo . stablehloOutputOperandAliasGet (
414+ MLIR. IR. Attribute (MLIR . API . stablehloOutputOperandAliasGet (
407415 MLIR. IR. context (),
408- len (args) == 1 ? 0 : 1 ,
409- len (args) == 1 ? C_NULL : Ref {Int64} (i- 1 ),
416+ length (args) == 1 ? 0 : 1 ,
417+ length (args) == 1 ? C_NULL : Ref {Int64} (i- 1 ),
410418 i- 1 ,
411419 0 ,
412420 C_NULL
413- )
421+ ) )
414422 )
415423 end
416424
417- output_operand_aliases= MLIR. ArrayAttr . get (MLIR . IR. context (), aliases)
418- call = MLIR. IR . Dialects. stablehlo. custom_call (mlir_args; result_0= restys, call_target_name= " reactant_gpu_call" , output_operand_aliases)
425+ output_operand_aliases= MLIR. IR. Attribute ( aliases)
426+ call = MLIR. Dialects. stablehlo. custom_call (mlir_args; result_0= restys, call_target_name= " reactant_gpu_call" , output_operand_aliases)
419427 for (i, res) in enumerate (rarrays)
420- ta . mlir_data = Reactant . Compiler . transpose_val (MLIR. IR. result (call, i- 1 ))
428+ res . mlir_data = transpose_val (MLIR. IR. result (call, i))
421429 end
422430 # CUDA.cuLaunchKernel(f,
423431 # blockdim.x, blockdim.y, blockdim.z,
0 commit comments