@@ -9,6 +9,12 @@ using Adapt
99
1010struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
1111 ptr:: Core.LLVMPtr{T,A}
12+
13+ function CuTracedArray {T,N,A,Size} (xs:: TracedRArray ) where {T,N,A,Size}
14+ push! (Reactant. Compiler. context_gc_vector[MLIR. IR. context ()], xs)
15+ ptr = Base. reinterpret (Core. LLVMPtr{T,CUDA. AS. Global}, Base. pointer_from_objref (xs))
16+ return new (ptr)
17+ end
1218end
1319
1420function Base. show (io:: IO , a:: AT ) where {AT<: CuTracedArray }
@@ -211,10 +217,34 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
211217 return _derived_array (a, T, dims)
212218end
213219
214- function Adapt. adapt_storage (:: CUDA.KernelAdaptor , xs:: TracedRArray{T,N} ) where {T,N}
215- res = CuTracedArray {T,N,CUDA.AS.Global,size(xs)} (
216- Base. reinterpret (Core. LLVMPtr{T,CUDA. AS. Global}, Base. pointer_from_objref (xs))
220+ struct ReactantKernelAdaptor end
221+
222+ function Adapt. adapt_storage (to:: ReactantKernelAdaptor , p:: CUDA.CuPtr )
223+ return error (" Cannot convert CuPtr argument of Reactant Kernel" )
224+ end
225+ function Adapt. adapt_storage (ka:: ReactantKernelAdaptor , xs:: DenseCuArray )
226+ return Adapt. adapt_storage (ka, Array (xs))
227+ end
228+ function Adapt. adapt_storage (ka:: ReactantKernelAdaptor , xs:: Array )
229+ return Adapt. adapt_storage (ka, Reactant. Ops. constant (xs))
230+ end
231+ function Adapt. adapt_structure (to:: ReactantKernelAdaptor , ref:: Base.RefValue )
232+ return error (" Cannot convert RefValue argument of Reactant Kernel" )
233+ end
234+ function Adapt. adapt_structure (
235+ to:: ReactantKernelAdaptor , bc:: Broadcast.Broadcasted{Style,<:Any,Type{T}}
236+ ) where {Style,T}
237+ return Broadcast. Broadcasted {Style} (
238+ (x... ) -> T (x... ), Adapt. adapt (to, bc. args), bc. axes
217239 )
240+ end
241+
242+ Reactant. @reactant_overlay @noinline function CUDA. cudaconvert (arg)
243+ return adapt (ReactantKernelAdaptor (), arg)
244+ end
245+
246+ function Adapt. adapt_storage (:: ReactantKernelAdaptor , xs:: TracedRArray{T,N} ) where {T,N}
247+ res = CuTracedArray {T,N,CUDA.AS.Global,size(xs)} (xs)
218248 return res
219249end
220250
383413function Reactant. make_tracer (
384414 seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs...
385415)
386- x = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, prev. ptr)):: TracedRArray
416+ x = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, prev. ptr))
417+ x = x:: TracedRArray
387418 Reactant. make_tracer (seen, x, path, mode; kwargs... )
388419 return prev
389420end
@@ -441,12 +472,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
441472
442473 # linearize kernel arguments
443474 seen = Reactant. OrderedIdDict ()
444- prev = Any[func. f, args... ]
445475 kernelargsym = gensym (" kernelarg" )
446- Reactant. make_tracer (seen, prev, (kernelargsym,), Reactant. TracedTrack)
447- @show prev
448- @show Core. Typeof (prev)
449- @show seen
476+ for (i, prev) in enumerate (Any[func. f, args... ])
477+ Reactant. make_tracer (seen, prev, (kernelargsym, i), Reactant. NoStopTracedTrack)
478+ end
450479 wrapper_tys = MLIR. IR. Type[]
451480 for arg in values (seen)
452481 if ! (arg isa TracedRArray || arg isa TracedRNumber)
@@ -539,16 +568,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
539568 if ! (arg isa TracedRArray || arg isa TracedRNumber)
540569 continue
541570 end
542- for p in Reactant. TracedUtils. get_paths (arg)
571+
572+ paths = Reactant. TracedUtils. get_paths (arg)
573+
574+ arg = arg. mlir_data
575+ arg = Reactant. TracedUtils. transpose_val (arg)
576+ push! (restys, MLIR. IR. type (arg))
577+ push! (mlir_args, arg)
578+
579+ for p in paths
543580 if p[1 ] != = kernelargsym
544581 continue
545582 end
546-
547- arg = arg. mlir_data
548- arg = Reactant. TracedUtils. transpose_val (arg)
549- push! (restys, MLIR. IR. type (arg))
550- push! (mlir_args, arg)
551-
552583 # Get the allocation corresponding to which arg we're doing
553584 alloc = allocs[p[2 ]][1 ]
554585
@@ -583,9 +614,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
583614 ),
584615 ),
585616 )
586-
587- argidx += 1
588617 end
618+ argidx += 1
589619 end
590620
591621 MLIR. IR. block! (wrapbody) do
0 commit comments