Skip to content

Commit 6cef30a

Browse files
committed
test: add reverse mode AD testing
1 parent c5ae0ce commit 6cef30a

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

test/lotka_volterra.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using SciMLStructures
1111
using SciMLStructures: Tunable, canonicalize
1212
using ForwardDiff
1313
using StableRNGs
14+
using DifferentiationInterface
15+
using SciMLSensitivity
16+
using Zygote: Zygote
1417

1518
function lotka_ude()
1619
@variables t x(t)=3.1 y(t)=1.5
@@ -86,14 +89,22 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
8689
end
8790
end
8891

89-
of = OptimizationFunction{true}(loss, AutoForwardDiff())
92+
of = OptimizationFunction{true}(loss, AutoZygote())
9093

9194
ps = (prob, sol_ref, get_vars, get_refs, set_x);
9295

9396
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9497
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9598

96-
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
99+
∇l1 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoForwardDiff(), x0)
100+
∇l2 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoFiniteDiff(), x0)
101+
∇l3 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoZygote(), x0)
102+
103+
@test all(.!isnan.(∇l1))
104+
@test !iszero(∇l1)
105+
106+
@test ∇l1 ≈ ∇l2 rtol=1e-2
107+
@test ∇l1 ≈ ∇l3 rtol=1e-5
97108

98109
op = OptimizationProblem(of, x0, ps)
99110

0 commit comments

Comments (0)