Skip to content

Commit 82da2a4

Browse files
author
William Moses
committed
fix
1 parent 66fdf67 commit 82da2a4

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@ using Reactant:
55
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
66
using ReactantCore: @trace
77

8+
using Adapt
9+
10+
# Adapt.jl hook selected when adapting with a `CUDA.KernelAdaptor`: converts a
# `TracedRArray{T,N}` into a `CuDeviceArray{T,N,CUDA.AS.Global}` with the same element
# type `T`, rank `N` and `size` as `xs`.
function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
# NOTE(review): assumes `pointer` is defined for `xs.mlir_data.value` and that the
# returned address is valid as the base pointer of a device array -- TODO confirm.
11+
CuDeviceArray{T,N,CUDA.AS.Global}(pointer(xs.mlir_data.value), size(xs))
12+
end
813

914
# Module-level dictionary with `Any` keys and `Any` values.
# NOTE(review): presumably a cache of kernel instances (going by the name); nothing in
# this excerpt reads or writes it -- confirm against the rest of the file.
const _kernel_instances = Dict{Any, Any}()
1015

1116
function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
1217
cuda = CUDA.active_state()
1318

14-
F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete))
15-
tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete))
16-
17-
1819
Base.@lock CUDA.cufunction_lock begin
1920
# compile the function
2021
cache = CUDA.compiler_cache(cuda.context)
21-
source = CUDA.methodinstance(F2, tt2)
22+
source = CUDA.methodinstance(F, tt)
2223
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
2324
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)
2425

0 commit comments

Comments
 (0)