|
1 | 1 | @testitem "Adjoint Tests" tags = [:nopre] begin |
2 | 2 | # Skip adjoint tests on Julia 1.12+ due to Enzyme/SciMLSensitivity compatibility issues |
3 | | - # To re-enable: change condition to `true` or `VERSION < v"1.13"` |
4 | | - if VERSION >= v"1.12" |
5 | | - @info "Skipping adjoint tests on Julia $(VERSION) - Enzyme/SciMLSensitivity not compatible with 1.12+" |
6 | | - return |
7 | | - end |
8 | | - |
9 | | - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake |
| 3 | + # To re-enable: change condition to `false` or `VERSION >= v"1.13"` |
| 4 | + @static if VERSION < v"1.12" |
| 5 | + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake |
10 | 6 |
|
11 | | - ff(u, p) = u .^ 2 .- p |
| 7 | + ff(u, p) = u .^ 2 .- p |
12 | 8 |
|
13 | | - function solve_nlprob(p) |
14 | | - prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) |
15 | | - sol = solve(prob, NewtonRaphson()) |
16 | | - res = sol isa AbstractArray ? sol : sol.u |
17 | | - return sum(abs2, res) |
18 | | - end |
| 9 | + function solve_nlprob(p) |
| 10 | + prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) |
| 11 | + sol = solve(prob, NewtonRaphson()) |
| 12 | + res = sol isa AbstractArray ? sol : sol.u |
| 13 | + return sum(abs2, res) |
| 14 | + end |
19 | 15 |
|
20 | | - p = [3.0, 2.0] |
| 16 | + p = [3.0, 2.0] |
21 | 17 |
|
22 | | - ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) |
23 | | - ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) |
24 | | - ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) |
25 | | - ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) |
26 | | - ∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1] |
| 18 | + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) |
| 19 | + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) |
| 20 | + ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) |
| 21 | + ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) |
| 22 | + ∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1] |
27 | 23 |
|
28 | | - cache = Mooncake.prepare_gradient_cache(solve_nlprob, p) |
29 | | - ∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2] |
| 24 | + cache = Mooncake.prepare_gradient_cache(solve_nlprob, p) |
| 25 | + ∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2] |
30 | 26 |
|
31 | | - @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme |
32 | | - @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme |
33 | | - @test_broken ∂p_forwarddiff ≈ ∂p_mooncake |
| 27 | + @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme |
| 28 | + @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme |
| 29 | + @test_broken ∂p_forwarddiff ≈ ∂p_mooncake |
| 30 | + else |
| 31 | + @info "Skipping adjoint tests on Julia $(VERSION) - Enzyme/SciMLSensitivity not compatible with 1.12+" |
| 32 | + end |
34 | 33 | end |
0 commit comments