@@ -40,7 +40,6 @@ using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
     OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random

 rng = Random.default_rng()
-Random.seed!(rng, 0)
 tspan = (0.0f0, 8.0f0)

 ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
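The hunk above drops the explicit seeding of `rng`, so the initial parameters of `ann` now vary between runs. If a reproducible run is still wanted locally, the removed call can simply be reinstated right after creating the RNG; a minimal sketch (this line is not part of the patch):

```julia
rng = Random.default_rng()
Random.seed!(rng, 0)  # fixed seed, matching the pre-patch example
```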
@@ -72,7 +71,7 @@ function loss_adjoint(θ)
 end

 l = loss_adjoint(θ)
-cb = function (state, l; doplot = false)
+cb = function (state, l; doplot = true)
     println(l)

     ps = ComponentArray(state.u, ax)
@@ -90,12 +89,16 @@
 # Setup and run the optimization

 loss1 = loss_adjoint(θ)
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

 optprob = Optimization.OptimizationProblem(optf, θ)
 res1 = Optimization.solve(
-    optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 300)
+    optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 100)
+
+optprob2 = Optimization.OptimizationProblem(optf, res1.u)
+res2 = Optimization.solve(
+    optprob2, OptimizationOptimJL.BFGS(), callback = cb, maxiters = 100)
 ```

 Now that the system is in a better behaved part of parameter space, we return to
@@ -110,8 +113,8 @@ function loss_adjoint(θ)
 end
 optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

-optprob3 = Optimization.OptimizationProblem(optf3, res1.u)
-res3 = Optimization.solve(optprob3, OptimizationOptimisers.Adam(0.01), maxiters = 100)
+optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
+res3 = Optimization.solve(optprob3, OptimizationOptimJL.BFGS(), maxiters = 100)
 ```

 Now let's see what we received:
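Taken together, the hunks replace the single 300-iteration Adam run with a staged scheme: a shorter Adam phase to move into a well-behaved region, a BFGS phase for fast local convergence, and a final BFGS refinement under the redefined loss, warm-started from `res2.u`. Assembled from the `+` lines above, the optimization portion of the patched example reads roughly as follows (a sketch of the post-patch state; `loss_adjoint`, `cb`, and `θ` are defined earlier in the example):

```julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

# Stage 1: Adam takes cheap first-order steps toward a reasonable region.
optprob = Optimization.OptimizationProblem(optf, θ)
res1 = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 100)

# Stage 2: BFGS refines from Adam's result, converging faster near a minimum.
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(
    optprob2, OptimizationOptimJL.BFGS(), callback = cb, maxiters = 100)

# Final stage, after loss_adjoint is redefined: BFGS again, started from res2.
optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
res3 = Optimization.solve(optprob3, OptimizationOptimJL.BFGS(), maxiters = 100)
```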