Skip to content

Commit 057744f

Browse files
authored
EnzymeTestUtils: GPUArray support (#2692)
* EnzymeTestUtils: Cuarray support * fix * add file
1 parent 196bde6 commit 057744f

File tree

3 files changed

+130
-4
lines changed

3 files changed

+130
-4
lines changed

lib/EnzymeTestUtils/Project.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeTestUtils"
22
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
33
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -12,16 +12,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1414

15+
[weakdeps]
16+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
17+
18+
[extensions]
19+
EnzymeTestUtilsGPUArraysCoreExt = ["Enzyme", "GPUArraysCore"]
20+
1521
[compat]
1622
ConstructionBase = "1.4.1"
1723
Enzyme = "0.13.78"
1824
EnzymeCore = "0.5, 0.6, 0.7, 0.8"
1925
FiniteDifferences = "0.12.33"
26+
GPUArraysCore = "0.1.6, 0.2"
2027
MetaTesting = "0.1"
2128
Quaternions = "0.7"
2229
julia = "1.10"
2330

2431
[extras]
32+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2533
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2634
MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56"
2735
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
module EnzymeTestUtilsGPUArraysCoreExt
2+
3+
using GPUArraysCore
4+
using EnzymeTestUtils
5+
using Enzyme
6+
7+
function EnzymeTestUtils.acopyto!(dst, src::AbstractGPUArray)
8+
temp = Array{eltype(src)}(undef, size(src))
9+
Base.copyto!(temp, src)
10+
EnzymeTestUtils.acopyto!(dst, temp)
11+
end
12+
13+
# basic containers: loop over defined elements, recursively converting them to vectors
14+
function EnzymeTestUtils.to_vec(x::AbstractGPUArray{<:EnzymeTestUtils.ElementType}, seen_vecs::EnzymeTestUtils.AliasDict)
15+
has_seen = haskey(seen_vecs, x)
16+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
17+
if has_seen || is_const
18+
x_vec = Float32[]
19+
else
20+
x_vec = reshape(x, length(x))
21+
seen_vecs[x] = x_vec
22+
end
23+
sz = size(x)
24+
function FastGPUArray_from_vec(x_vec_new::AbstractVector{<:EnzymeTestUtils.ElementType}, seen_xs::EnzymeTestUtils.AliasDict)
25+
if xor(has_seen, haskey(seen_xs, x))
26+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
27+
end
28+
has_seen && return reshape(seen_xs[x], size(x))
29+
is_const && return x
30+
x_new = reshape(x_vec_new, sz)
31+
if Core.Typeof(x_new) != Core.Typeof(x)
32+
x_new = Core.Typeof(x)(x_new)
33+
end
34+
seen_xs[x] = x_new
35+
return x_new
36+
end
37+
return x_vec, FastGPUArray_from_vec
38+
end
39+
40+
# basic containers: loop over defined elements, recursively converting them to vectors
41+
function to_vec(x::AbstractGPUArray{<:Complex{<:EnzymeTestUtils.ElementType}}, seen_vecs::EnzymeTestUtils.AliasDict)
42+
has_seen = haskey(seen_vecs, x)
43+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
44+
if has_seen || is_const
45+
x_vec = Float32[]
46+
else
47+
y = reshape(x, length(x))
48+
x_vec = vcat(real.(y), imag.(y))
49+
seen_vecs[x] = x_vec
50+
end
51+
sz = size(x)
52+
function ComplexGPUArray_from_vec(x_vec_new::AbstractVector{<:EnzymeTestUtils.ElementType}, seen_xs::EnzymeTestUtils.AliasDict)
53+
if xor(has_seen, haskey(seen_xs, x))
54+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
55+
end
56+
has_seen && return reshape(seen_xs[x], size(x))
57+
is_const && return x
58+
x_new = Array{eltype(x)}(undef, sz)
59+
@inbounds @simd for i in 1:length(x)
60+
x_new[i] = eltype(x)(x_vec_new[i], x_vec_new[i + length(x)])
61+
end
62+
x_new = Core.Typeof(x)(x_new)
63+
seen_xs[x] = x_new
64+
return x_new
65+
end
66+
return x_vec, ComplexGPUArray_from_vec
67+
end
68+
69+
# basic containers: loop over defined elements, recursively converting them to vectors
70+
function to_vec(x::AbstractGPUArray, seen_vecs::EnzymeTestUtils.AliasDict)
71+
has_seen = haskey(seen_vecs, x)
72+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
73+
if has_seen || is_const
74+
x_vec = Float32[]
75+
else
76+
x_vecs = nothing
77+
from_vecs = []
78+
subvec_inds = UnitRange{Int}[]
79+
l = 0
80+
for i in eachindex(x)
81+
isassigned(x, i) || continue
82+
xi_vec, xi_from_vec = to_vec(x[i], seen_vecs)
83+
push!(subvec_inds, (l + 1):(l + length(xi_vec)))
84+
push!(from_vecs, xi_from_vec)
85+
x_vecs = EnzymeTestUtils.append_or_merge(x_vecs, xi_vec)
86+
l += length(xi_vec)
87+
end
88+
89+
if x_vecs === nothing
90+
x_vecs = (Float32[], true)
91+
end
92+
x_vec = x_vecs[1]
93+
seen_vecs[x] = x_vec
94+
end
95+
function GPUArray_from_vec(x_vec_new::AbstractVector{<:EnzymeTestUtils.ElementType}, seen_xs::EnzymeTestUtils.AliasDict)
96+
if xor(has_seen, haskey(seen_xs, x))
97+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
98+
end
99+
has_seen && return reshape(seen_xs[x], size(x))
100+
is_const && return x
101+
x_new = Array{eltype(x_vew_new)}(undef, size(x))
102+
k = 1
103+
for i in eachindex(x)
104+
isassigned(x, i) || continue
105+
xi = from_vecs[k](@view(x_vec_new[subvec_inds[k]]), seen_xs)
106+
x_new[i] = xi
107+
k += 1
108+
end
109+
x_new = Core.Typeof(x)(x_new)
110+
seen_xs[x] = x_new
111+
return x_new
112+
end
113+
return x_vec, GPUArray_from_vec
114+
end
115+
116+
end # module

lib/EnzymeTestUtils/src/to_vec.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ function to_vec(x::Array{<:ElementType}, seen_vecs::AliasDict)
8686
return x_vec, FastArray_from_vec
8787
end
8888

89+
acopyto!(dst, src) = Base.copyto!(dst, src)
90+
8991
# Returns (vector, bool if new allocation)
90-
function append_or_merge(prev::Union{Nothing, Tuple{Vector, Bool}}, newv::Vector)::Tuple{Vector, Bool}
92+
function append_or_merge(prev::Union{Nothing, Tuple{AbstractVector, Bool}}, newv::AbstractVector)::Tuple{AbstractVector, Bool}
9193
if prev === nothing
9294
return (newv, false)
9395
elseif prev[2] && eltype(newv) <: eltype(prev[1])
@@ -100,8 +102,8 @@ function append_or_merge(prev::Union{Nothing, Tuple{Vector, Bool}}, newv::Vector
100102
return prev
101103
else
102104
res = Vector{ET2}(undef, length(prev[1]) + length(newv))
103-
copyto!(@view(res[1:length(prev[1])]), prev[1])
104-
copyto!(@view(res[length(prev[1])+1:end]), newv)
105+
acopyto!(@view(res[1:length(prev[1])]), prev[1])
106+
acopyto!(@view(res[length(prev[1])+1:end]), newv)
105107
return (res, true)
106108
end
107109
end

0 commit comments

Comments
 (0)