@@ -7,11 +7,202 @@ using ReactantCore: @trace

using Adapt

-# function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
-#     res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs))
-#     @show res, xs
-#     return res
-# end
+struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
+    ptr::Core.LLVMPtr{T,A}
+end
+
+
+Base.show(io::IO, a::AT) where {AT<:CuTracedArray} =
+    CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a)))
+
+## array interface
+
+Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T)
+Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size
+Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x)
+Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x)
+@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A}
+    Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i)
+end
+
+
+## conversions
+
+Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} =
+    x.ptr
+
+
+## indexing intrinsics
+
+CUDA.@device_function @inline function arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayref_union(A, index)
+    else
+        arrayref_bits(A, index)
+    end
+end
+
+@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T}
+    unsafe_load(pointer(A), index)
+end
+
+@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+
+    # generate code that conditionally loads a value based on the selector value.
+    # lacking noreturn, we return T to avoid inference thinking this can return Nothing.
+    ex = :(Base.llvmcall("unreachable", $T, Tuple{}))
+    for (sel, typ) in Iterators.reverse(enumerate(typs))
+        ex = quote
+            if selector == $(sel - 1)
+                ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr)
+                unsafe_load(ptr, 1)
+            else
+                $ex
+            end
+        end
+    end
+
+    quote
+        selector_ptr = typetagdata(A, index)
+        selector = unsafe_load(selector_ptr)
+
+        data_ptr = pointer(A, index)
+
+        return $ex
+    end
+end
+
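+# To make the @generated expansion concrete: for a two-member union T with
+# Base.uniontypes(T) == (S1, S2), the quoted body above unrolls to roughly the
+# following (selector bytes are 0-based):
+#
+#     selector_ptr = typetagdata(A, index)
+#     selector = unsafe_load(selector_ptr)
+#     data_ptr = pointer(A, index)
+#     if selector == 0
+#         unsafe_load(reinterpret(Core.LLVMPtr{S1,AS}, data_ptr), 1)
+#     else
+#         if selector == 1
+#             unsafe_load(reinterpret(Core.LLVMPtr{S2,AS}, data_ptr), 1)
+#         else
+#             Base.llvmcall("unreachable", T, Tuple{})
+#         end
+#     end
+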
+CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayset_union(A, x, index)
+    else
+        arrayset_bits(A, x, index)
+    end
+    return A
+end
+
+@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    unsafe_store!(pointer(A), x, index)
+end
+
+@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+    sel = findfirst(isequal(x), typs)
+
+    quote
+        selector_ptr = typetagdata(A, index)
+        unsafe_store!(selector_ptr, $(UInt8(sel - 1)))
+
+        data_ptr = pointer(A, index)
+
+        unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1)
+        return
+    end
+end
+
+CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    unsafe_cached_load(pointer(A), index)
+end
+
+
+## indexing
+
+Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()
+
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
+    arrayref(A, i1)
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
+    arrayset(A, convert(T, x)::T, i1)
+
+# preserve the specific integer type when indexing device arrays,
+# to avoid extending 32-bit hardware indices to 64-bit.
+Base.to_index(::CuTracedArray, i::Integer) = i
+
+# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
+# See also: https://github.com/JuliaLang/julia/pull/42289
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray,
+                                       I::Union{Integer,CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)]
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x,
+                                        I::Union{Integer,CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)] = x
+
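+# E.g. inside a kernel, `A[threadIdx().x]` keeps the Int32 hardware index as-is rather
+# than widening it to Int64, and mixed Integer/CartesianIndex subscripts are linearized
+# through Base._to_linear_index before reaching the arrayref/arrayset intrinsics above.
+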
+## const indexing
+
137+ """
138+ Const(A::CuTracedArray)
139+
140+ Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not
141+ modify an CuTracedArray for the duration of the current kernel.
142+
143+ This API can only be used on devices with compute capability 3.5 or higher.
144+
145+ !!! warning
146+ Experimental API. Subject to change without deprecation.
147+ """
148+ struct Const{T,N,AS} <: DenseArray{T,N}
149+ a:: CuTracedArray{T,N,AS}
150+ end
151+ Base. Experimental. Const (A:: CuTracedArray ) = Const (A)
152+
+Base.IndexStyle(::Type{<:Const}) = IndexLinear()
+Base.size(C::Const) = size(C.a)
+Base.axes(C::Const) = axes(C.a)
+Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1)
+
+# deprecated
+Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1)
+
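+# Usage sketch, with a hypothetical kernel that only reads `a`:
+#
+#     function read_only_kernel(a, out)
+#         a = Base.Experimental.Const(a)
+#         i = threadIdx().x
+#         @inbounds out[i] = a[i]   # loads go through const_arrayref / unsafe_cached_load
+#         return nothing
+#     end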
+
+## other
+
+@inline function Base.iterate(A::CuTracedArray, i=1)
+    if (i % UInt) - 1 < length(A)
+        (@inbounds A[i], i + 1)
+    else
+        nothing
+    end
+end
+
+function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A}
+    err = GPUArrays._reinterpret_exception(T, a)
+    err === nothing || throw(err)
+
+    if sizeof(T) == sizeof(S) # fast case
+        return CuTracedArray{T,N,A,size(a)}(reinterpret(Core.LLVMPtr{T,A}, a.ptr))
+    end
+
+    isize = size(a)
+    size1 = div(isize[1] * sizeof(S), sizeof(T))
+    osize = tuple(size1, Base.tail(isize)...)
+    return CuTracedArray{T,N,A,osize}(reinterpret(Core.LLVMPtr{T,A}, a.ptr))
+end
+
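+# Worked example: reinterpreting a (4, 3) CuTracedArray{Float64} as Float32 keeps the
+# trailing dimensions and rescales the first: size1 = div(4 * sizeof(Float64), sizeof(Float32)) = 8,
+# giving an (8, 3) CuTracedArray{Float32} over the same pointer.
+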
+## reshape
+
+function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A}
+    if prod(dims) != length(a)
+        throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)"))
+    end
+    if N == M && dims == size(a)
+        return a
+    end
+    # re-wrap the same pointer, carrying the new size in the Size type parameter
+    return CuTracedArray{T,N,A,dims}(a.ptr)
+end
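+# E.g. a (3, 4) CuTracedArray can be reshaped to (6, 2) or (12,); the pointer is reused
+# and only the size carried in the type changes.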
+
+
+function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
+    res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
+    @show res, xs
+    return res
+end
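+# Note: this adaptor wraps no device memory. It stores pointer_from_objref(xs), i.e. the
+# host address of the TracedRArray object itself, so the launch path further below can
+# recover the traced array with unsafe_pointer_to_objref and record its mlir_data in the
+# traced kernel launch.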

const _kernel_instances = Dict{Any,Any}()

@@ -24,6 +215,8 @@ function compile(job)
    asm, meta = CUDA.GPUCompiler.compile(:asm, job)
    mod = meta.ir
    modstr = string(mod)
+    @show mod
+    @show modstr
    # check if we'll need the device runtime
    undefined_fs = filter(collect(functions(meta.ir))) do f
        isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -208,7 +401,9 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD
    aliases = MLIR.API.MlirAttribute[]
    for (i, a) in enumerate(args)
        @show a
-        arg = nothing
+        @assert a isa CuTracedArray
+        # recover the host-side TracedRArray whose object pointer was stashed by the adaptor
+        ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+        arg = ta.mlir_data
        arg = Reactant.Compiler.transpose_val(arg)
        push!(restys, MLIR.IR.Type(arg))
        push!(aliases,