Commit 7c7c8ed
Author: William Moses
Commit message: wip
1 parent b8e3570 commit 7c7c8ed

File tree

1 file changed: +201 −6 lines changed

ext/ReactantCUDAExt.jl

Lines changed: 201 additions & 6 deletions
@@ -7,11 +7,202 @@ using ReactantCore: @trace
 
 using Adapt
 
-#function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
-#    res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs))
-#    @show res, xs
-#    return res
-#end
+struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
+    ptr::Core.LLVMPtr{T,A}
+end
+
+
+Base.show(io::IO, a::AT) where AT <: CuTracedArray =
+    CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a)))
+
+## array interface
+
+Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T)
+Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size
+Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x)
+Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x)
+@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A}
+    Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i)
+end
+
+
+## conversions
+
+Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} =
+    x.ptr
+
+
+## indexing intrinsics
+
+CUDA.@device_function @inline function arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayref_union(A, index)
+    else
+        arrayref_bits(A, index)
+    end
+end
+
+@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T}
+    unsafe_load(pointer(A), index)
+end
+
+@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+
+    # generate code that conditionally loads a value based on the selector value.
+    # lacking noreturn, we return T to avoid inference thinking this can return Nothing.
+    ex = :(Base.llvmcall("unreachable", $T, Tuple{}))
+    for (sel, typ) in Iterators.reverse(enumerate(typs))
+        ex = quote
+            if selector == $(sel-1)
+                ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr)
+                unsafe_load(ptr, 1)
+            else
+                $ex
+            end
+        end
+    end
+
+    quote
+        selector_ptr = typetagdata(A, index)
+        selector = unsafe_load(selector_ptr)
+
+        data_ptr = pointer(A, index)
+
+        return $ex
+    end
+end
+
+CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayset_union(A, x, index)
+    else
+        arrayset_bits(A, x, index)
+    end
+    return A
+end
+
+@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    unsafe_store!(pointer(A), x, index)
+end
+
+@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+    sel = findfirst(isequal(x), typs)
+
+    quote
+        selector_ptr = typetagdata(A, index)
+        unsafe_store!(selector_ptr, $(UInt8(sel-1)))
+
+        data_ptr = pointer(A, index)
+
+        unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1)
+        return
+    end
+end
+
+CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    unsafe_cached_load(pointer(A), index)
+end
+
+
+## indexing
+
+Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()
+
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
+    arrayref(A, i1)
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
+    arrayset(A, convert(T,x)::T, i1)
+
+# preserve the specific integer type when indexing device arrays,
+# to avoid extending 32-bit hardware indices to 64-bit.
+Base.to_index(::CuTracedArray, i::Integer) = i
+
+# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
+# See also: https://github.com/JuliaLang/julia/pull/42289
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray,
+                                       I::Union{Integer, CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)]
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x,
+                                        I::Union{Integer, CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)] = x
+
+
+## const indexing
+
+"""
+    Const(A::CuTracedArray)
+
+Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not
+modify a CuTracedArray for the duration of the current kernel.
+
+This API can only be used on devices with compute capability 3.5 or higher.
+
+!!! warning
+    Experimental API. Subject to change without deprecation.
+"""
+struct Const{T,N,AS} <: DenseArray{T,N}
+    a::CuTracedArray{T,N,AS}
+end
+Base.Experimental.Const(A::CuTracedArray) = Const(A)
+
+Base.IndexStyle(::Type{<:Const}) = IndexLinear()
+Base.size(C::Const) = size(C.a)
+Base.axes(C::Const) = axes(C.a)
+Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1)
+
+# deprecated
+Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1)
+
+
+## other
+
+@inline function Base.iterate(A::CuTracedArray, i=1)
+    if (i % UInt) - 1 < length(A)
+        (@inbounds A[i], i + 1)
+    else
+        nothing
+    end
+end
+
+function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A}
+    err = GPUArrays._reinterpret_exception(T, a)
+    err === nothing || throw(err)
+
+    if sizeof(T) == sizeof(S) # fast case
+        return CuTracedArray{T,N,A,size(a)}(reinterpret(Core.LLVMPtr{T,A}, a.ptr))
+    end
+
+    isize = size(a)
+    size1 = div(isize[1]*sizeof(S), sizeof(T))
+    osize = tuple(size1, Base.tail(isize)...)
+    return CuTracedArray{T,N,A,osize}(reinterpret(Core.LLVMPtr{T,A}, a.ptr))
+end
+
+
+## reshape
+
+function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A}
+    if prod(dims) != length(a)
+        throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)"))
+    end
+    if N == M && dims == size(a)
+        return a
+    end
+    CuTracedArray{T,N,A,dims}(a.ptr)
+end
+
+
+
+function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
+    res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
+    @show res, xs
+    return res
+end
 
 const _kernel_instances = Dict{Any, Any}()
 
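Note: the hunk above introduces CuTracedArray, whose Size is a type parameter and whose only field is a pointer, with linear indexing routed through arrayref/arrayset. Below is a host-side sketch of that pattern, illustrative only: plain Ptr in place of Core.LLVMPtr, no CUDA address spaces, and TracedLike is a hypothetical name, not part of the extension.

# Host-side sketch (illustrative only): the element count lives in the Size type
# parameter, so the struct carries nothing but a pointer, and linear indexing goes
# straight through unsafe_load/unsafe_store!, mirroring arrayref_bits/arrayset_bits.
struct TracedLike{T,N,Size} <: DenseArray{T,N}
    ptr::Ptr{T}
end

Base.size(::TracedLike{T,N,Size}) where {T,N,Size} = Size
Base.IndexStyle(::Type{<:TracedLike}) = IndexLinear()
Base.@propagate_inbounds function Base.getindex(A::TracedLike{T}, i::Integer) where {T}
    @boundscheck checkbounds(A, i)
    unsafe_load(A.ptr, i)                    # raw load from the stored pointer
end
Base.@propagate_inbounds function Base.setindex!(A::TracedLike{T}, x, i::Integer) where {T}
    @boundscheck checkbounds(A, i)
    unsafe_store!(A.ptr, convert(T, x), i)   # raw store through the stored pointer
end

v = [10.0, 20.0, 30.0]
GC.@preserve v begin                         # keep the buffer alive while raw pointers are in use
    a = TracedLike{Float64,1,(3,)}(pointer(v))
    a[2] = 99.0
    @assert v[2] == 99.0 && size(a) == (3,)
end

The real type additionally tags the pointer with an address space and dispatches to the union-aware arrayref_union/arrayset_union paths for isbits-union element types.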
@@ -24,6 +215,8 @@ function compile(job)
     asm, meta = CUDA.GPUCompiler.compile(:asm, job)
     mod = meta.ir
     modstr = string(mod)
+    @show mod
+    @show modstr
     # check if we'll need the device runtime
     undefined_fs = filter(collect(functions(meta.ir))) do f
         isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -208,7 +401,9 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD
     aliases = MLIR.API.MlirAttribute[]
     for (i, a) in enumerate(args)
         @show a
-        arg = nothing
+        @assert a isa CuDeviceArray
+        ta = Base.unsafe_pointer_to_objref(a.ptr)::TracedRArray
+        arg = ta.mlir_data
         arg = Reactant.Compiler.transpose_val(arg)
         push!(restys, MLIR.IR.Type(arg))
         push!(aliases,
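Note: this launch path relies on the Adapt.adapt_storage method from the first hunk, which stores Base.pointer_from_objref of the host-side TracedRArray (reinterpreted to an LLVMPtr) as the CuTracedArray's ptr; the call overload then turns that pointer back into the TracedRArray to read its mlir_data. A minimal stand-alone sketch of that round trip, using hypothetical HostArray/DevHandle names and plain Julia pointers rather than the CUDA/Reactant/MLIR machinery:

# Sketch of the objref round trip: the adaptor stashes an object reference as a raw
# pointer, and the launch path recovers the original object from it.
mutable struct HostArray              # stand-in for TracedRArray (mutable => stable objref)
    data::Vector{Float64}
end

struct DevHandle                      # stand-in for CuTracedArray: carries only a pointer
    ptr::Ptr{Cvoid}
end

wrap(x::HostArray) = DevHandle(Base.pointer_from_objref(x))             # adapt_storage side
unwrap(h::DevHandle) = Base.unsafe_pointer_to_objref(h.ptr)::HostArray  # launch side

x = HostArray([1.0, 2.0, 3.0])
GC.@preserve x begin                  # the host object must stay rooted while the pointer is live
    h = wrap(x)
    @assert unwrap(h) === x           # the same object comes back; this is where the
end                                   # extension reads ta.mlir_data

In the extension the pointer is further reinterpreted to Core.LLVMPtr{T,CUDA.AS.Global} so the handle still looks like a device array to the generated kernel code, which is why the loop above can assume a .ptr field on each argument.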
