
Commit cb9bb30

fix test enzyme (#2563)

* fix test enzyme
* fix
* don't test enzyme on nightly
* remove enzyme from test project
* update

1 parent b205266 · commit cb9bb30

File tree

5 files changed: +53 −133 lines changed

test/Project.toml

Lines changed: 0 additions & 2 deletions
@@ -2,7 +2,6 @@
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -22,7 +21,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-Enzyme = "0.13"
 FiniteDifferences = "0.12"
 GPUArraysCore = "0.1"
 GPUCompiler = "0.27"
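(Enzyme is dropped from the static test dependencies and its `[compat]` entry removed; as the runtests.jl change below shows, it is now `Pkg.add`ed on demand, and only on Julia versions where the Enzyme tests actually run.)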

test/ext_enzyme/enzyme.jl

Lines changed: 8 additions & 113 deletions
@@ -1,100 +1,8 @@
-using Test
-using Flux
-import Zygote
-
-using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal
-
-using Functors
-using FiniteDifferences
-
-
-function gradient_fd(f, x...)
-    f = f |> f64
-    x = [cpu(x) for x in x]
-    ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
-    ps = [f64(x[1]) for x in ps_and_res]
-    res = [x[2] for x in ps_and_res]
-    fdm = FiniteDifferences.central_fdm(5, 1)
-    gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...)
-    return ((re(g) for (re, g) in zip(res, gs))...,)
-end
-
-function gradient_ez(f, x...)
-    args = []
-    for x in x
-        if x isa Number
-            push!(args, Active(x))
-        else
-            push!(args, Duplicated(x, make_zero(x)))
-        end
-    end
-    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
-    return g
-end
-
-function test_grad(g1, g2; broken=false)
-    fmap_with_path(g1, g2) do kp, x, y
-        :state ∈ kp && return # ignore RNN and LSTM state
-        if x isa AbstractArray{<:Number}
-            # @show kp
-            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
-        end
-        return x
-    end
-end
-
-function test_enzyme_grad(loss, model, x)
-    Flux.trainmode!(model)
-    l = loss(model, x)
-    @test loss(model, x) == l # Check loss doesn't change with multiple runs
-
-    grads_fd = gradient_fd(loss, model, x) |> cpu
-    grads_flux = Flux.gradient(loss, model, x) |> cpu
-    grads_enzyme = gradient_ez(loss, model, x) |> cpu
-
-    # test_grad(grads_flux, grads_enzyme)
-    test_grad(grads_fd, grads_enzyme)
-end
-
-@testset "gradient_ez" begin
-    @testset "number and arrays" begin
-        f(x, y) = sum(x.^2) + y^3
-        x = Float32[1, 2, 3]
-        y = 3f0
-        g = gradient_ez(f, x, y)
-        @test g[1] isa Array{Float32}
-        @test g[2] isa Float32
-        @test g[1] ≈ 2x
-        @test g[2] ≈ 3*y^2
-    end
-
-    @testset "struct" begin
-        struct SimpleDense{W, B, F}
-            weight::W
-            bias::B
-            σ::F
-        end
-        SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ)
-        (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias)
-
-        model = SimpleDense(2, 4)
-        x = randn(Float32, 2)
-        loss(model, x) = sum(model(x))
-
-        g = gradient_ez(loss, model, x)
-        @test g[1] isa SimpleDense
-        @test g[2] isa Array{Float32}
-        @test g[1].weight isa Array{Float32}
-        @test g[1].bias isa Array{Float32}
-        @test g[1].weight ≈ ones(Float32, 4, 1) .* x'
-        @test g[1].bias ≈ ones(Float32, 4)
-    end
-end
+using Enzyme: Enzyme, Duplicated, Const, Active

 @testset "Models" begin
     function loss(model, x)
-        sum(model(x))
+        mean(model(x))
     end

     models_xs = [
@@ -117,41 +25,28 @@ end
     for (model, x, name) in models_xs
         @testset "Enzyme grad check $name" begin
             println("testing $name with Enzyme")
-            test_enzyme_grad(loss, model, x)
+            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
         end
     end
 end

-@testset "Recurrence Tests" begin
+@testset "Recurrent Layers" begin
     function loss(model, x)
-        for i in 1:3
-            x = model(x)
-        end
-        return sum(x)
-    end
-
-    struct LSTMChain
-        rnn1
-        rnn2
-    end
-    function (m::LSTMChain)(x)
-        st = m.rnn1(x)
-        st = m.rnn2(st[1])
-        return st[1]
+        mean(model(x))
     end

     models_xs = [
         # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
         # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
         # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
         # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        # (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
+        # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
     ]

     for (model, x, name) in models_xs
         @testset "check grad $name" begin
             println("testing $name")
-            test_enzyme_grad(loss, model, x)
+            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
         end
     end
 end
@@ -219,7 +114,7 @@ end
     z = _duplicated(zeros32(3))
     @test_broken Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0]  # Constant memory is stored (or returned) to a differentiable variable
     @test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
-    @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing
+    @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing broken=VERSION >= v"1.11"

     @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0]  # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
     @test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
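For context, the deleted `gradient_ez` helper (and its replacement `enzyme_withgradient` in test/test_utils.jl below) both follow the same Enzyme reverse-mode pattern: wrap scalars in `Active`, wrap arrays and structs in `Duplicated` with a zeroed shadow, then read gradients back from the return value or the shadow. A minimal standalone sketch of that pattern, not part of this diff:

```julia
using Enzyme: Enzyme, Active, Duplicated, ReverseWithPrimal

f(x, y) = sum(abs2, x) + y^3

x  = Float32[1, 2, 3]
dx = Enzyme.make_zero(x)   # shadow buffer; Enzyme accumulates dL/dx here

# ReverseWithPrimal returns (derivatives, primal value).
ret = Enzyme.autodiff(ReverseWithPrimal, Enzyme.Const(f), Active,
                      Duplicated(x, dx), Active(3f0))

ret[2]     # primal: sum(abs2, x) + 3f0^3 == 41
dx         # gradient w.r.t. x, written in place: 2x
ret[1][2]  # gradient w.r.t. the Active scalar y: 3y^2 == 27
```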

test/runtests.jl

Lines changed: 12 additions & 6 deletions
@@ -3,6 +3,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
+import Optimisers

 using Zygote
 const gradient = Flux.gradient # both Flux & Zygote export this on 0.15
@@ -21,18 +22,24 @@ using Functors: fmapstructure_with_path
 # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
 # ENV["FLUX_TEST_ENZYME"] = "false"

+const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
+if FLUX_TEST_ENZYME
+    Pkg.add("Enzyme")
+    using Enzyme: Enzyme
+end
+
 include("test_utils.jl") # for test_gradients

 Random.seed!(0)

 include("testsuite/normalization.jl")

 function flux_testsuite(dev)
-  @testset "Flux Test Suite" begin
-    @testset "Normalization" begin
-      normalization_testsuite(dev)
-    end
+    @testset "Flux Test Suite" begin
+        @testset "Normalization" begin
+            normalization_testsuite(dev)
         end
+    end
 end
@@ -157,9 +164,8 @@ end
     @info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them."
 end

-if get(ENV, "FLUX_TEST_ENZYME", "true") == "true"
+if FLUX_TEST_ENZYME
     @testset "Enzyme" begin
-        import Enzyme
         include("ext_enzyme/enzyme.jl")
     end
 else
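The net effect: Enzyme is only tested where it is expected to work. `FLUX_TEST_ENZYME` defaults to `"true"` below Julia 1.12 (the `v"1.12-"` bound excludes nightly), can be overridden from the environment, and Enzyme itself is installed only when the flag is set. A hedged sketch of opting in locally (the exact workflow may differ):

```julia
# From a Julia session in a Flux.jl checkout:
ENV["FLUX_TEST_ENZYME"] = "true"   # force-enable, even on nightly

using Pkg
Pkg.test("Flux")   # runtests.jl will Pkg.add("Enzyme") before running the suite
```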

test/test_utils.jl

Lines changed: 31 additions & 3 deletions
@@ -19,6 +19,22 @@ function finitediff_withgradient(f, x...)
     return y, FiniteDifferences.grad(fdm, f, x...)
 end

+function enzyme_withgradient(f, x...)
+    args = []
+    for x in x
+        if x isa Number
+            push!(args, Enzyme.Active(x))
+        else
+            push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))
+        end
+    end
+    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
+    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
+    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
+    return ret[2], g
+end
+
+
 function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
     fmapstructure_with_path(a, b) do kp, x, y
         if x isa AbstractArray
@@ -37,12 +53,12 @@ function test_gradients(
     test_grad_f = true,
     test_grad_x = true,
     compare_finite_diff = true,
+    compare_enzyme = false,
     loss = (f, xs...) -> mean(f(xs...)),
     )

-    if !test_gpu && !compare_finite_diff
-        error("You should either compare finite diff vs CPU AD \
-               or CPU AD vs GPU AD.")
+    if !test_gpu && !compare_finite_diff && !compare_enzyme
+        error("You should either compare numerical gradient methods or CPU vs GPU.")
     end

     ## Let's make sure first that the forward pass works.
@@ -70,6 +86,12 @@ function test_gradients(
         check_equal_leaves(g, g_fd; rtol, atol)
     end

+    if compare_enzyme
+        y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
+        @test y ≈ y_ez rtol=rtol atol=atol
+        check_equal_leaves(g, g_ez; rtol, atol)
+    end
+
     if test_gpu
         # Zygote gradient with respect to input on GPU.
         y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...)
@@ -93,6 +115,12 @@ function test_gradients(
         check_equal_leaves(g, g_fd; rtol, atol)
     end

+    if compare_enzyme
+        y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
+        @test y ≈ y_ez rtol=rtol atol=atol
+        check_equal_leaves(g, g_ez; rtol, atol)
+    end
+
     if test_gpu
         # Zygote gradient with respect to f on GPU.
         y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu)
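With these hooks in place, Enzyme checking becomes a one-keyword opt-in on the shared `test_gradients` helper. A hedged usage sketch — the model and input below are illustrative, not taken from the diff:

```julia
using Flux, Statistics

model = Chain(Dense(3 => 4, relu), Dense(4 => 2))  # hypothetical test model
x = randn(Float32, 3, 5)

# Compare Zygote's gradient against Enzyme's, skipping finite differences,
# exactly as the updated ext_enzyme/enzyme.jl tests do:
test_gradients(model, x;
    loss = (m, x) -> mean(m(x)),
    compare_finite_diff = false,
    compare_enzyme = true)
```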

test/train.jl

Lines changed: 2 additions & 9 deletions
@@ -1,18 +1,11 @@
-using Flux
-# using Flux.Train
-import Optimisers
-
-using Test
-using Random
-import Enzyme

 function train_enzyme!(fn, model, args...; kwargs...)
     Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
 end

 for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))

-    if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
+    if name == "Enzyme" && !FLUX_TEST_ENZYME
         continue
     end

@@ -50,7 +43,7 @@ end
 for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
     # TODO reinstate Enzyme
     name == "Enzyme" && continue
-    # if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
+    # if name == "Enzyme" && !FLUX_TEST_ENZYME
     #     continue
     #  end

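For reference, `train_enzyme!` routes `Flux.train!` through Enzyme by wrapping the model in `Enzyme.Duplicated` with a zeroed shadow, which is how Flux selects the Enzyme backend. A hedged sketch of the same pattern outside the test suite (data and optimiser are illustrative):

```julia
using Flux
import Enzyme, Optimisers

model = Dense(2 => 1)
opt_state = Flux.setup(Optimisers.Adam(0.01), model)

data = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:10]
loss(m, x, y) = Flux.mse(m(x), y)

# Same call shape as train_enzyme! above: the Duplicated wrapper carries
# the model and a shadow that Enzyme writes gradients into.
Flux.train!(loss, Enzyme.Duplicated(model, Enzyme.make_zero(model)), data, opt_state)
```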
