@@ -5,20 +5,21 @@ using Reactant:
55 Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
66using ReactantCore: @trace
77
8+ using Adapt
9+
# Adapt.adapt_storage method for the (CUDA.KernelAdaptor, TracedRArray{T,N}) argument pair.
# Returns a CuDeviceArray{T,N,CUDA.AS.Global} constructed from
# `pointer(xs.mlir_data.value)` and `size(xs)`; element type and rank are carried over
# from the traced array's type parameters.
# NOTE(review): presumably this lets a traced array stand in as a device array when kernel
# arguments are adapted — confirm against the kernel-launch path.
# NOTE(review): whether `pointer` has a method for `xs.mlir_data.value` is not visible
# in this view — verify.
# NOTE(review): `CuDeviceArray` is referenced unqualified while the other CUDA names in
# this view use the `CUDA.` prefix — confirm it is in scope.
10+ function Adapt. adapt_storage (:: CUDA.KernelAdaptor , xs:: TracedRArray{T,N} ) where {T,N}
11+ CuDeviceArray {T,N,CUDA.AS.Global} (pointer (xs. mlir_data. value), size (xs))
12+ end
813
# Constant bound to an initially empty Dict{Any,Any}.
# NOTE(review): judging by the name, presumably a cache of kernel instances; how it is
# keyed and populated is not visible in this view — confirm.
914const _kernel_instances = Dict {Any, Any} ()
1015
1116function recufunction (f:: F , tt:: TT = Tuple{}; kwargs... ) where {F,TT}
1217 cuda = CUDA. active_state ()
1318
14- F2 = Reactant. traced_type (F, (), Val (Reactant. TracedToConcrete))
15- tt2 = Reactant. traced_type (tt, (), Val (Reactant. TracedToConcrete))
16-
17-
1819 Base. @lock CUDA. cufunction_lock begin
1920 # compile the function
2021 cache = CUDA. compiler_cache (cuda. context)
21- source = CUDA. methodinstance (F2, tt2 )
22+ source = CUDA. methodinstance (F, tt )
2223 config = CUDA. compiler_config (cuda. device; kwargs... ):: CUDA.CUDACompilerConfig
2324 fun = CUDA. GPUCompiler. cached_compilation (cache, source, config, CUDA. compile, CUDA. link)
2425
0 commit comments