Skip to content

Commit 46e8210

Browse files
committed
cuconvert
1 parent 060f245 commit 46e8210

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,12 @@ function Adapt.adapt_structure(
239239
)
240240
end
241241

242-
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
242+
function recudaconvert(arg)
243243
return adapt(ReactantKernelAdaptor(), arg)
244244
end
245+
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
246+
return recudaconvert(arg)
247+
end
245248

246249
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
247250
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
@@ -456,7 +459,7 @@ end
456459

457460
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
458461
args...;
459-
convert=Val(false),
462+
convert=Val(true),
460463
blocks::CuDim=1,
461464
threads::CuDim=1,
462465
cooperative::Bool=false,
@@ -466,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
466469
blockdim = CUDA.CuDim3(blocks)
467470
threaddim = CUDA.CuDim3(threads)
468471

472+
if convert == Val(true)
473+
args = recudaconvert.(args)
474+
end
475+
469476
mlir_args = MLIR.IR.Value[]
470477
restys = MLIR.IR.Type[]
471478
aliases = MLIR.IR.Attribute[]

0 commit comments

Comments
 (0)