Skip to content

Commit d149c85

Browse files
author
William Moses
committed
continuing
1 parent c7afab7 commit d149c85

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,13 @@ function link(job, compiled)
382382
return compiled
383383
end
384384

385+
# transpose_val(val)
#
# Emit a `stablehlo.transpose` operation on the MLIR value `val` and return that
# operation's first result.
#
# The permutation attribute is the list of 0-based dimension indices of `val` in
# reverse order, i.e. [N-1, ..., 1, 0], where N is the number of dimensions
# reported by `size(MLIR.IR.type(val))`. The result therefore has the dimension
# order of `val` fully reversed.
#
# NOTE(review): presumably this bridges Julia's column-major layout and the
# row-major layout expected on the MLIR side — confirm against the callers.
function transpose_val(val)
386+
attr = MLIR.IR.DenseArrayAttribute(
387+
Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...]
388+
)
389+
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
390+
end
391+
385392
function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
386393
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
387394
@show args
@@ -392,32 +399,33 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th
392399

393400
mlir_args = MLIR.IR.Value[]
394401
restys = MLIR.IR.Type[]
395-
aliases = MLIR.API.MlirAttribute[]
402+
aliases = MLIR.IR.Attribute[]
396403
rarrays = TracedRArray[]
397404
for (i, a) in enumerate(args)
398405
@show a
399406
@assert a isa CuTracedArray
400407
ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
401408
push!(rarrays, ta)
402409
arg = ta.mlir_data
403-
arg = Reactant.Compiler.transpose_val(arg)
404-
push!(restys, MLIR.IR.Type(arg))
410+
arg = transpose_val(arg)
411+
@show arg
412+
push!(restys, MLIR.IR.type(arg))
405413
push!(aliases,
406-
MLIR.IR.Dialects.stablehlo.stablehloOutputOperandAliasGet(
414+
MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet(
407415
MLIR.IR.context(),
408-
len(args) == 1 ? 0 : 1,
409-
len(args) == 1 ? C_NULL : Ref{Int64}(i-1),
416+
length(args) == 1 ? 0 : 1,
417+
length(args) == 1 ? C_NULL : Ref{Int64}(i-1),
410418
i-1,
411419
0,
412420
C_NULL
413-
)
421+
))
414422
)
415423
end
416424

417-
output_operand_aliases=MLIR.ArrayAttr.get(MLIR.IR.context(), aliases)
418-
call = MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
425+
output_operand_aliases=MLIR.IR.Attribute(aliases)
426+
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
419427
for (i, res) in enumerate(rarrays)
420-
ta.mlir_data = Reactant.Compiler.transpose_val(MLIR.IR.result(call, i-1))
428+
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
421429
end
422430
#CUDA.cuLaunchKernel(f,
423431
# blockdim.x, blockdim.y, blockdim.z,

0 commit comments

Comments
 (0)