@@ -99,12 +99,14 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     loss = sum(abs2, sol .- xy_data)
-    return loss, sol
+    return loss
 end
 
 # Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
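For context on the hunk above: in recent Optimization.jl versions the callback's first argument is an `OptimizationState` rather than the raw parameter vector, which is why the solution is re-simulated from `state.u`. A minimal sketch of what the new two-argument callback receives (illustrative only, not part of this diff):

```julia
# Sketch: the two-argument callback form adopted above.
# `state` is an Optimization.jl OptimizationState; `state.u` holds the
# current parameter vector, and `l` is the current loss value.
function callback(state, l)
    @show l state.u  # inspect the loss and the parameter estimates
    return false     # `false` means "continue optimizing"
end
```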
@@ -278,37 +280,28 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     l = sum(abs2, sol .- xy_data)
-    return l, sol
+    return l
 end
 ```
 
-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.
-
 ### Step 5: Solve the Optimization Problem
 
-This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
+This step will look very similar to [the first optimization tutorial](@ref first_opt).
 The `Optimization.solve` function can accept an optional callback function
-to monitor the optimization process using extra arguments returned from `loss`.
+to monitor the optimization process.
 
 The callback syntax is always:
 
 ```
 callback(
-    optimization variables,
+    state,
     the current loss value,
-    other arguments returned from the loss function, ...
 )
 ```
 
-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
 The return value of the callback function should default to `false`.
 `Optimization.solve` will halt if/when the callback function returns `true` instead.
 Typically the `return` statement would monitor the loss value
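The halting convention described above lends itself to early stopping. A hedged sketch (the function name and the tolerance are illustrative choices, not from the tutorial):

```julia
# Sketch: stop `Optimization.solve` early by returning `true` from the
# callback once the loss is small enough. The 1e-4 tolerance is only
# an illustrative value.
function early_stop_callback(state, l)
    display(l)
    return l < 1e-4  # halting condition: fit is good enough
end
```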
@@ -318,8 +311,10 @@ More details about callbacks in Optimization.jl can be found
 [here](https://docs.sciml.ai/Optimization/stable/API/solve/).
 
 ```@example odefit
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
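Taken together, the hunks move the tutorial to a one-argument loss plus a `(state, l)` callback. A self-contained sketch of that pattern, with a toy decay ODE standing in for the tutorial's model (the package and solver choices here, `Tsit5` and `BFGS`, are assumptions, not part of this diff):

```julia
# Self-contained sketch: one-argument loss plus a (state, l) callback
# with early stopping. Toy model, not the tutorial's ODE.
using OrdinaryDiffEq, Optimization, OptimizationOptimJL

decay!(du, u, p, t) = (du[1] = -p[1] * u[1])
prob = ODEProblem(decay!, [1.0], (0.0, 5.0), [0.5])
data = exp.(-0.8 .* (0.0:1.0:5.0))  # synthetic data, true decay rate 0.8

function loss(newp)
    newprob = remake(prob, p = newp)
    sol = solve(newprob, Tsit5(), saveat = 1)
    return sum(abs2, first.(sol.u) .- data)  # scalar loss only
end

# Print progress; halt once the fit is essentially exact.
callback(state, l) = (println(l); l < 1e-10)

optf = OptimizationFunction((u, _) -> loss(u), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, [0.5])  # initial guess for the rate
pfit = solve(optprob, BFGS(); callback = callback)
```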