Skip to content

Commit d82b31a

Browse files
authored
test: unbreak CUDA CI (EnzymeAD#337)
* chore: add compat entries
* feat: add copyto! for ConcreteRArray
* feat: `@jit` compile ConcreteRArray broadcasting
* fix: manually zero out the lower triangular and upper triangular values
* fix: only do it in tests
* feat: compile mapreduce for ConcreteRArray
* test: manual array conversion
* revert: change in Ops.cholesky
* revert: remove unnecessary changes
* fix: only compile non-CPU broadcasting
* fix: address reviewer comments
* chore: apply suggestions from code review
1 parent 66d6cfc commit d82b31a

File tree

6 files changed

+88
-22
lines changed

6 files changed

+88
-22
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- name: "Run CompatHelper"
3939
run: |
4040
import CompatHelper
41-
CompatHelper.main()
41+
CompatHelper.main(; subdirs=[".", "test", "lib/ReactantCore"])
4242
shell: julia --color=yes {0}
4343
env:
4444
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ ArrayInterface = "7.10"
4242
CEnum = "0.4, 0.5"
4343
Downloads = "1.6"
4444
Enzyme = "0.13.21"
45-
EnzymeCore = "0.8.6, 0.8.7, 0.8.8"
45+
EnzymeCore = "0.8.8"
4646
GPUArraysCore = "0.1.6, 0.2"
4747
LinearAlgebra = "1.10"
48-
NNlib = "0.9.24"
48+
NNlib = "0.9.26"
4949
OrderedCollections = "1"
5050
Preferences = "1.4"
5151
ReactantCore = "0.1.2"

src/ConcreteRArray.jl

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
215215
end
216216

217217
XLA.await(a.data)
218-
if XLA.BufferOnCPU(a.data.buffer)
218+
if buffer_on_cpu(a)
219219
buf = a.data.buffer
220220
GC.@preserve buf begin
221221
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
@@ -246,7 +246,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
246246
end
247247

248248
XLA.await(a.data)
249-
if XLA.BufferOnCPU(a.data.buffer)
249+
if buffer_on_cpu(a)
250250
buf = a.data.buffer
251251
GC.@preserve buf begin
252252
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
@@ -289,15 +289,52 @@ end
289289

290290
# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
291291
function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}})
292-
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
293-
if !Base.isconcretetype(ElType)
294-
throw(
295-
ErrorException(
296-
"`copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
297-
),
298-
)
292+
for x in bc.args
293+
x isa ConcreteRArray && XLA.await(x.data)
299294
end
300295

301-
aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
302-
return ConcreteRArray(aux)
296+
all_on_cpu = all(buffer_on_cpu, bc.args)
297+
if all_on_cpu
298+
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
299+
if !Base.isconcretetype(ElType)
300+
throw(
301+
ErrorException(
302+
"`copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
303+
),
304+
)
305+
end
306+
aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
307+
return ConcreteRArray(aux)
308+
end
309+
310+
fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,))
311+
return fn(bc.args...)
312+
end
313+
314+
function Base.copyto!(dest::ConcreteRArray, src::ConcreteRArray)
315+
dest.data = src.data
316+
return dest
317+
end
318+
319+
function Base.mapreduce(
320+
@nospecialize(f),
321+
@nospecialize(op),
322+
@nospecialize(A::ConcreteRArray{T,N});
323+
dims=:,
324+
init=nothing,
325+
) where {T,N}
326+
fn = Reactant.compile(CallMapReduce(f, op, dims, init), (A,))
327+
return fn(A)
328+
end
329+
330+
struct CallMapReduce{Fn,Op,Dims,Init}
331+
f::Fn
332+
op::Op
333+
dims::Dims
334+
init::Init
303335
end
336+
337+
(f::CallMapReduce)(A) = Base.mapreduce(f.f, f.op, A; f.dims, f.init)
338+
339+
buffer_on_cpu(::Any) = true
340+
buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer)

src/TracedRNumber.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ for (jlop, hloop) in (
240240
(:(Base.FastMath.exp_fast), :exponential),
241241
(:(Base.log), :log),
242242
(:(Base.sqrt), :sqrt),
243+
(:(Base.ceil), :ceil),
244+
(:(Base.floor), :floor),
243245
)
244246
@eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
245247
OutTy = $(hloop === :abs) ? real(T) : T

test/Project.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,24 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1919
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2020
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2121
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22+
23+
[compat]
24+
ArrayInterface = "7.10"
25+
BenchmarkTools = "1.5"
26+
Enzyme = "0.13.21"
27+
FFTW = "1.8"
28+
Flux = "0.15"
29+
Functors = "0.5"
30+
InteractiveUtils = "1.10"
31+
LinearAlgebra = "1.10"
32+
Lux = "1.4.1"
33+
LuxLib = "1.3"
34+
MLUtils = "0.4.4"
35+
NNlib = "0.9.26"
36+
OneHotArrays = "0.2.6"
37+
Optimisers = "0.4"
38+
Random = "1.10"
39+
SafeTestsets = "0.1"
40+
SpecialFunctions = "2.4"
41+
Statistics = "1.10"
42+
Test = "1.10"

test/ops.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,18 @@ end
8282
end
8383

8484
@testset "cholesky" begin
85-
g(x) = Ops.cholesky(x; lower=true)
85+
# cholesky in stablehlo for the other triangle is implementation defined.
86+
# See https://github.com/EnzymeAD/Reactant.jl/issues/338 for more details.
87+
g1(x) = triu(Ops.cholesky(x))
88+
g2(x) = tril(Ops.cholesky(x; lower=true))
89+
8690
x = ConcreteRArray([
8791
10.0 2.0 3.0
8892
2.0 5.0 6.0
8993
3.0 6.0 9.0
9094
])
91-
@test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x)
92-
@test transpose(cholesky(Array(x)).U) ≈ @jit g(x)
95+
@test cholesky(Array(x)).U ≈ @jit g1(x)
96+
@test transpose(cholesky(Array(x)).U) ≈ @jit g2(x)
9397

9498
x = ConcreteRArray(
9599
[
@@ -98,8 +102,9 @@ end
98102
3.0+4.0im 3.0+2.0im 9.0+0.0im
99103
],
100104
)
101-
@test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x)
102-
@test adjoint(cholesky(Array(x)).U) ≈ @jit g(x)
105+
106+
@test cholesky(Array(x)).U ≈ @jit g1(x)
107+
@test adjoint(cholesky(Array(x)).U) ≈ @jit g2(x)
103108
end
104109

105110
@testset "clamp" begin
@@ -210,13 +215,14 @@ end
210215
]
211216
# NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
212217
@test sum(a .* b) ≈ @jit f1(a, b)
213-
@test kron(reshape(a, length(a), 1), reshape(b, 1, length(b))) ≈ @jit fouter(a, b)
218+
@test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b))) ≈
219+
@jit fouter(a, b)
214220
@test a .* b ≈ @jit fouter_batch1(a, b)
215221
end
216222

217223
a = ConcreteRArray([1 2; 3 4])
218224
b = ConcreteRArray([5 6; -7 -8])
219-
@test a' * b == @jit f1(a, b)
225+
@test Array(a)' * Array(b) == @jit f1(a, b)
220226
end
221227

222228
@testset "einsum" begin
@@ -239,7 +245,7 @@ end
239245
x = reshape(a, (2, 2))
240246
y = reshape(b, (2, 2))
241247
@test x .* y ≈ @jit f3(x, y)
242-
@test x * y ≈ @jit f4(x, y)
248+
@test Array(x) * Array(y) ≈ @jit f4(x, y)
243249
end
244250
end
245251

0 commit comments

Comments
 (0)