@@ -11,6 +11,9 @@ using SciMLStructures
1111using  SciMLStructures:  Tunable, canonicalize
1212using  ForwardDiff
1313using  StableRNGs
14+ using  DifferentiationInterface
15+ using  SciMLSensitivity
16+ using  Zygote:  Zygote
1417
1518function  lotka_ude ()
1619    @variables  t x (t)= 3.1  y (t)= 1.5 
@@ -59,7 +62,7 @@ prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
5962
6063model_true =  structural_simplify (lotka_true ())
6164prob_true =  ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [])
62- sol_ref =  solve (prob_true, Rodas4 () )
65+ sol_ref =  solve (prob_true, Rodas5P (), abstol  =   1e-10 , reltol  =   1e-8 )
6366
6467x0 =  default_values (sys)[nn. p]
6568
@@ -71,7 +74,7 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
7174    new_p =  set_x (prob, x)
7275    new_prob =  remake (prob, p =  new_p, u0 =  eltype (x).(prob. u0))
7376    ts =  sol_ref. t
74-     new_sol =  solve (new_prob, Rodas4 () , saveat =  ts)
77+     new_sol =  solve (new_prob, Rodas5P (), abstol  =   1e-10 , reltol  =   1e-8 , saveat =  ts)
7578
7679    loss =  zero (eltype (x))
7780
@@ -86,14 +89,22 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
8689    end 
8790end 
8891
89- of =  OptimizationFunction {true} (loss, AutoForwardDiff ())
92+ of =  OptimizationFunction {true} (loss, AutoZygote ())
9093
9194ps =  (prob, sol_ref, get_vars, get_refs, set_x);
9295
9396@test_call  target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9497@test_opt  target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
9598
96- @test  all (.! isnan .(ForwardDiff. gradient (Base. Fix2 (of, ps), x0)))
99+ ∇l1 =  DifferentiationInterface. gradient (Base. Fix2 (of, ps), AutoForwardDiff (), x0)
100+ ∇l2 =  DifferentiationInterface. gradient (Base. Fix2 (of, ps), AutoFiniteDiff (), x0)
101+ ∇l3 =  DifferentiationInterface. gradient (Base. Fix2 (of, ps), AutoZygote (), x0)
102+ 
103+ @test  all (.! isnan .(∇l1))
104+ @test  ! iszero (∇l1)
105+ 
106+ @test  ∇l1≈ ∇l2 rtol= 1e-2 
107+ @test  ∇l1≈ ∇l3 rtol= 1e-5 
97108
98109op =  OptimizationProblem (of, x0, ps)
99110
@@ -111,7 +122,7 @@ op = OptimizationProblem(of, x0, ps)
111122#      false
112123#  end
113124
114- res =  solve (op, Adam (), maxiters =  5000 )# , callback = plot_cb)
125+ res =  solve (op, Adam (), maxiters =  10000 )# , callback = plot_cb)
115126
116127@test  res. objective <  1 
117128
0 commit comments