@@ -1,100 +1,8 @@
-using Test
-using Flux
-import Zygote
-
-using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal
-
-using Functors
-using FiniteDifferences
-
-
-function gradient_fd(f, x...)
-    f = f |> f64
-    x = [cpu(x) for x in x]
-    ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
-    ps = [f64(x[1]) for x in ps_and_res]
-    res = [x[2] for x in ps_and_res]
-    fdm = FiniteDifferences.central_fdm(5, 1)
-    gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...)
-    return ((re(g) for (re, g) in zip(res, gs))...,)
-end
-
-function gradient_ez(f, x...)
-    args = []
-    for x in x
-        if x isa Number
-            push!(args, Active(x))
-        else
-            push!(args, Duplicated(x, make_zero(x)))
-        end
-    end
-    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
-    return g
-end
-
-function test_grad(g1, g2; broken=false)
-    fmap_with_path(g1, g2) do kp, x, y
-        :state ∈ kp && return # ignore RNN and LSTM state
-        if x isa AbstractArray{<:Number}
-            # @show kp
-            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
-        end
-        return x
-    end
-end
-
-function test_enzyme_grad(loss, model, x)
-    Flux.trainmode!(model)
-    l = loss(model, x)
-    @test loss(model, x) == l # Check loss doesn't change with multiple runs
-
-    grads_fd = gradient_fd(loss, model, x) |> cpu
-    grads_flux = Flux.gradient(loss, model, x) |> cpu
-    grads_enzyme = gradient_ez(loss, model, x) |> cpu
-
-    # test_grad(grads_flux, grads_enzyme)
-    test_grad(grads_fd, grads_enzyme)
-end
-
-@testset "gradient_ez" begin
-    @testset "number and arrays" begin
-        f(x, y) = sum(x.^2) + y^3
-        x = Float32[1, 2, 3]
-        y = 3f0
-        g = gradient_ez(f, x, y)
-        @test g[1] isa Array{Float32}
-        @test g[2] isa Float32
-        @test g[1] ≈ 2x
-        @test g[2] ≈ 3*y^2
-    end
-
-    @testset "struct" begin
-        struct SimpleDense{W, B, F}
-            weight::W
-            bias::B
-            σ::F
-        end
-        SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ)
-        (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias)
-
-        model = SimpleDense(2, 4)
-        x = randn(Float32, 2)
-        loss(model, x) = sum(model(x))
-
-        g = gradient_ez(loss, model, x)
-        @test g[1] isa SimpleDense
-        @test g[2] isa Array{Float32}
-        @test g[1].weight isa Array{Float32}
-        @test g[1].bias isa Array{Float32}
-        @test g[1].weight ≈ ones(Float32, 4, 1) .* x'
-        @test g[1].bias ≈ ones(Float32, 4)
-    end
-end
+using Enzyme: Enzyme, Duplicated, Const, Active
 
 @testset "Models" begin
     function loss(model, x)
-        sum(model(x))
+        mean(model(x))
     end
 
     models_xs = [
@@ -117,7 +25,7 @@
     for (model, x, name) in models_xs
         @testset "Enzyme grad check $name" begin
            println("testing $name with Enzyme")
-            test_enzyme_grad(loss, model, x)
+            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
        end
    end
 end
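
For reference, the removed `gradient_ez` helper that `test_gradients(...; compare_enzyme=true)` supersedes drove Enzyme by hand: each argument was wrapped in an Enzyme activity type and passed to `Enzyme.autodiff`. A minimal standalone sketch of that pattern follows; the `Dense` model, input, and loss are illustrative and not part of this diff.

using Flux
using Enzyme: Enzyme, Active, Duplicated, Const, ReverseWithPrimal, make_zero

model = Dense(3 => 2)            # illustrative model, not from the diff
x = randn(Float32, 3, 4)         # illustrative input batch
loss(m, x) = sum(abs2, m(x))     # scalar loss, so the return activity is Active

dmodel = make_zero(model)        # shadow structure that receives the gradient

# Reverse-mode AD: the gradient w.r.t. the model accumulates into `dmodel`,
# while the input is held constant.
Enzyme.autodiff(ReverseWithPrimal, loss, Active, Duplicated(model, dmodel), Const(x))

dmodel.weight                    # gradient of the loss w.r.t. model.weight
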
@@ -124,34 +32,21 @@
 
-@testset "Recurrence Tests" begin
+@testset "Recurrent Layers" begin
     function loss(model, x)
-        for i in 1:3
-            x = model(x)
-        end
-        return sum(x)
-    end
-
-    struct LSTMChain
-        rnn1
-        rnn2
-    end
-    function (m::LSTMChain)(x)
-        st = m.rnn1(x)
-        st = m.rnn2(st[1])
-        return st[1]
+        mean(model(x))
     end
 
     models_xs = [
         # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
         # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
         # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
         # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        # (LSTMChain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "LSTMChain(LSTM, LSTM)"),
+        # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
     ]
 
     for (model, x, name) in models_xs
         @testset "check grad $name" begin
             println("testing $name")
-            test_enzyme_grad(loss, model, x)
+            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
         end
     end
 end
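
The `test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)` calls above rely on Flux's shared test helper. Outside the test suite, the same Zygote-vs-Enzyme comparison can be written with `Flux.gradient` directly, assuming a recent Flux where wrapping arguments in Enzyme's `Duplicated`/`Const` selects the Enzyme backend (the `LayerNorm` checks in the last hunk below use exactly that mechanism). A rough sketch; the model, data, and tolerance are illustrative.

using Flux
using Statistics: mean
using Enzyme: Duplicated, Const, make_zero

model = Dense(3 => 2, relu)      # illustrative model
x = randn(Float32, 3, 8)
loss(m, x) = mean(m(x))

g_zygote = Flux.gradient(loss, model, x)                                        # default Zygote backend
g_enzyme = Flux.gradient(loss, Duplicated(model, make_zero(model)), Const(x))   # Enzyme backend

# The two backends should agree on the parameter gradients.
@assert isapprox(g_zygote[1].weight, g_enzyme[1].weight; rtol=1e-4)
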
@@ -219,7 +114,7 @@
     z = _duplicated(zeros32(3))
     @test_broken Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0] # Constant memory is stored (or returned) to a differentiable variable
     @test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
-    @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing
+    @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing broken=VERSION >= v"1.11"
 
     @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
     @test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
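
The only behavioural change in this last hunk is the `broken=VERSION >= v"1.11"` keyword, which uses the Test stdlib's conditional broken marking: when the condition is true the test is expected to fail and is recorded as Broken, otherwise it must pass as usual. A tiny self-contained illustration, with a hypothetical version-dependent function standing in for the real failure:

using Test

# Hypothetical stand-in that misbehaves only on newer Julia versions.
flaky() = VERSION >= v"1.11" ? 1 : 0

@testset "conditionally broken" begin
    # Recorded as Broken on Julia >= 1.11, as a normal Pass otherwise.
    @test flaky() == 0 broken=VERSION >= v"1.11"
end
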