Skip to content

Commit 4eb4454

Browse files
cleanup Reactant and Enzyme tests (#2578)
1 parent 1ec93e9 commit 4eb4454

File tree

6 files changed

+52
-30
lines changed

6 files changed

+52
-30
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1818
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1919
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
2020
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21-
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2221
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2322
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2423
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -61,7 +60,6 @@ OneHotArrays = "0.2.4"
6160
Optimisers = "0.4.1"
6261
Preferences = "1"
6362
ProgressLogging = "0.1"
64-
Reactant = "0.2.16"
6563
Reexport = "1.0"
6664
Setfield = "1.1"
6765
SpecialFunctions = "2.1.2"

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1414
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1515
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1616
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17-
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1817
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1918
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2019
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/ext_enzyme/enzyme.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@ using Enzyme: Enzyme, Duplicated, Const, Active
1212
(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
1313
(Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
1414
(Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
15-
# (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
15+
(Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
1616
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
1717
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
18-
# (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
18+
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
1919
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
2020
(first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
21-
# (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ)
22-
# (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ)
21+
(BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
22+
(first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
2323
]
2424

2525
for (model, x, name) in models_xs
2626
@testset "Enzyme grad check $name" begin
2727
println("testing $name with Enzyme")
28-
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
28+
test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
2929
end
3030
end
3131
end
@@ -36,17 +36,17 @@ end
3636
end
3737

3838
models_xs = [
39-
# (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
40-
# (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
41-
# (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
42-
# (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
43-
# (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
39+
(RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
40+
(LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
41+
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
42+
(Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
43+
(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
4444
]
4545

4646
for (model, x, name) in models_xs
4747
@testset "check grad $name" begin
4848
println("testing $name")
49-
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
49+
test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
5050
end
5151
end
5252
end
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# These are used only in test_utils.jl but cannot leave there
2+
# because Reactant is only optionally loaded and the macros fail when it is not loaded.
3+
4+
function reactant_withgradient(f, x...)
5+
y, g = Reactant.@jit enzyme_withgradient(f, x...)
6+
return y, g
7+
end
8+
9+
function reactant_loss(loss, x...)
10+
l = Reactant.@jit loss(x...)
11+
@test l isa Reactant.ConcreteRNumber
12+
return l
13+
end

test/runtests.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ using Pkg
1313
using FiniteDifferences: FiniteDifferences
1414
using Functors: fmapstructure_with_path
1515

16-
using Reactant
17-
1816
## Uncomment below to change the default test settings
1917
# ENV["FLUX_TEST_AMDGPU"] = "true"
2018
# ENV["FLUX_TEST_CUDA"] = "true"
@@ -23,20 +21,20 @@ using Reactant
2321
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
2422
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
2523
# ENV["FLUX_TEST_ENZYME"] = "false"
24+
# ENV["FLUX_TEST_REACTANT"] = "false"
2625

2726
const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
28-
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"
27+
28+
# Reactant will automatically select a GPU backend, if available, and TPU backend, if available.
29+
# Otherwise it will fall back to CPU.
30+
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
31+
VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"
2932

3033
if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
3134
Pkg.add("Enzyme")
3235
using Enzyme: Enzyme
3336
end
3437

35-
if FLUX_TEST_REACTANT
36-
Pkg.add("Reactant")
37-
using Reactant: Reactant
38-
end
39-
4038
include("test_utils.jl") # for test_gradients
4139

4240
Random.seed!(0)
@@ -182,7 +180,15 @@ end
182180
end
183181

184182
if FLUX_TEST_REACTANT
183+
## This Pkg.add has to be done after Pkg.add("CUDA") otherwise CUDA.jl
184+
## will not be functional and complain with:
185+
# ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
186+
#
187+
# │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
188+
Pkg.add("Reactant")
189+
using Reactant: Reactant
185190
@testset "Reactant" begin
191+
include("ext_reactant/test_utils_reactant.jl")
186192
include("ext_reactant/reactant.jl")
187193
end
188194
else

test/test_utils.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ end
3737

3838
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
3939
fmapstructure_with_path(a, b) do kp, x, y
40+
# @show kp
4041
if x isa AbstractArray
4142
@test x ≈ y rtol=rtol atol=atol
4243
elseif x isa Number
@@ -45,23 +46,29 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
4546
end
4647
end
4748

49+
# By default, this computes the gradients on cpu using the default AD (Zygote)
50+
# and compares them with finite differences.
51+
# Changing the arguments, you can assume the cpu Zygote gradients as the ground truth
52+
# and test other scenarios.
4853
function test_gradients(
4954
f,
5055
xs...;
5156
rtol=1e-4, atol=1e-4,
5257
test_gpu = false,
5358
test_reactant = false,
59+
test_enzyme = false,
5460
test_grad_f = true,
5561
test_grad_x = true,
5662
compare_finite_diff = true,
57-
compare_enzyme = false,
5863
loss = (f, xs...) -> mean(f(xs...)),
5964
)
6065

61-
if !test_gpu && !compare_finite_diff && !compare_enzyme && !test_reactant
66+
if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant
6267
error("You should either compare numerical gradients methods or CPU vs GPU.")
6368
end
6469

70+
Flux.trainmode!(f) # for layers like BatchNorm
71+
6572
## Let's make sure first that the forward pass works.
6673
l = loss(f, xs...)
6774
@test l isa Number
@@ -79,8 +86,7 @@ function test_gradients(
7986
cpu_dev = cpu_device()
8087
xs_re = xs |> reactant_dev
8188
f_re = f |> reactant_dev
82-
l_re = Reactant.@jit loss(f_re, xs_re...)
83-
@test l_re isa Reactant.ConcreteRNumber
89+
l_re = reactant_loss(loss, f_re, xs_re...)
8490
@test l ≈ l_re rtol=rtol atol=atol
8591
end
8692

@@ -97,7 +103,7 @@ function test_gradients(
97103
check_equal_leaves(g, g_fd; rtol, atol)
98104
end
99105

100-
if compare_enzyme
106+
if test_enzyme
101107
y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
102108
@test y ≈ y_ez rtol=rtol atol=atol
103109
check_equal_leaves(g, g_ez; rtol, atol)
@@ -113,7 +119,7 @@ function test_gradients(
113119

114120
if test_reactant
115121
# Enzyme gradient with respect to input on Reactant.
116-
y_re, g_re = Reactant.@jit enzyme_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
122+
y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
117123
@test y ≈ y_re rtol=rtol atol=atol
118124
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
119125
end
@@ -133,7 +139,7 @@ function test_gradients(
133139
check_equal_leaves(g, g_fd; rtol, atol)
134140
end
135141

136-
if compare_enzyme
142+
if test_enzyme
137143
y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
138144
@test y ≈ y_ez rtol=rtol atol=atol
139145
check_equal_leaves(g, g_ez; rtol, atol)
@@ -149,7 +155,7 @@ function test_gradients(
149155

150156
if test_reactant
151157
# Enzyme gradient with respect to input on Reactant.
152-
y_re, g_re = Reactant.@jit enzyme_withgradient(f -> loss(f, xs_re...), f_re)
158+
y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
153159
@test y ≈ y_re rtol=rtol atol=atol
154160
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
155161
end

0 commit comments

Comments
 (0)