Skip to content

Commit e0cd6d0

Browse files
change optimal control to forward for now
it's an enzyme bug
1 parent c49523b commit e0cd6d0

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

docs/src/examples/optimal_control/optimal_control.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
4040
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
4141
4242
rng = Random.default_rng()
43-
Random.seed!(rng, 0)
4443
tspan = (0.0f0, 8.0f0)
4544
4645
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
@@ -72,7 +71,7 @@ function loss_adjoint(θ)
7271
end
7372
7473
l = loss_adjoint(θ)
75-
cb = function (state, l; doplot = false)
74+
cb = function (state, l; doplot = true)
7675
println(l)
7776
7877
ps = ComponentArray(state.u, ax)
@@ -90,12 +89,16 @@ end
9089
# Setup and run the optimization
9190
9291
loss1 = loss_adjoint(θ)
93-
adtype = Optimization.AutoZygote()
92+
adtype = Optimization.AutoForward()
9493
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
9594
9695
optprob = Optimization.OptimizationProblem(optf, θ)
9796
res1 = Optimization.solve(
98-
optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 300)
97+
optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 100)
98+
99+
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
100+
res2 = Optimization.solve(
101+
optprob2, OptimizationOptimJL.BFGS(), callback = cb, maxiters = 100)
99102
```
100103

101104
Now that the system is in a better behaved part of parameter space, we return to
@@ -110,8 +113,8 @@ function loss_adjoint(θ)
110113
end
111114
optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
112115
113-
optprob3 = Optimization.OptimizationProblem(optf3, res1.u)
114-
res3 = Optimization.solve(optprob3, OptimizationOptimisers.Adam(0.01), maxiters = 100)
116+
optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
117+
res3 = Optimization.solve(optprob3, OptimizationOptimJL.BFGS(), maxiters = 100)
115118
```
116119

117120
Now let's see what we received:

0 commit comments

Comments
 (0)