Skip to content

Commit 5f5c73e

Browse files
Fix adjoint test to use @static if instead of return
The previous approach using `return` didn't work because @testitem isn't a function scope. Using @static if prevents the package imports from being attempted at all on Julia 1.12+. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 2f5c85e commit 5f5c73e

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

test/adjoint_tests.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,33 @@
11
@testitem "Adjoint Tests" tags = [:nopre] begin
22
# Skip adjoint tests on Julia 1.12+ due to Enzyme/SciMLSensitivity compatibility issues
3-
# To re-enable: change condition to `true` or `VERSION < v"1.13"`
4-
if VERSION >= v"1.12"
5-
@info "Skipping adjoint tests on Julia $(VERSION) - Enzyme/SciMLSensitivity not compatible with 1.12+"
6-
return
7-
end
8-
9-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake
3+
# To re-enable: change condition to `false` or `VERSION >= v"1.13"`
4+
@static if VERSION < v"1.12"
5+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake
106

11-
ff(u, p) = u .^ 2 .- p
7+
ff(u, p) = u .^ 2 .- p
128

13-
function solve_nlprob(p)
14-
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
15-
sol = solve(prob, NewtonRaphson())
16-
res = sol isa AbstractArray ? sol : sol.u
17-
return sum(abs2, res)
18-
end
9+
function solve_nlprob(p)
10+
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
11+
sol = solve(prob, NewtonRaphson())
12+
res = sol isa AbstractArray ? sol : sol.u
13+
return sum(abs2, res)
14+
end
1915

20-
p = [3.0, 2.0]
16+
p = [3.0, 2.0]
2117

22-
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
23-
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
24-
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
25-
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
26-
∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1]
18+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
19+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
20+
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
21+
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
22+
∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1]
2723

28-
cache = Mooncake.prepare_gradient_cache(solve_nlprob, p)
29-
∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2]
24+
cache = Mooncake.prepare_gradient_cache(solve_nlprob, p)
25+
∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2]
3026

31-
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
32-
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
33-
@test_broken ∂p_forwarddiff ∂p_mooncake
27+
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
28+
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
29+
@test_broken ∂p_forwarddiff ∂p_mooncake
30+
else
31+
@info "Skipping adjoint tests on Julia $(VERSION) - Enzyme/SciMLSensitivity not compatible with 1.12+"
32+
end
3433
end

0 commit comments

Comments
 (0)