@@ -99,7 +99,7 @@ Random.seed!(238248735)
9999 sensealg = ReverseDiffAdjoint ()))
100100 tmp_mean = mean (tmp_sol, dims = 3 )[:, :]
101101 tmp_var = var (tmp_sol, dims = 3 )[:, :]
102- sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var), tmp_mean
102+ sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var)
103103 end
104104
105105 function loss_op (θ)
@@ -112,13 +112,13 @@ Random.seed!(238248735)
112112 sensealg = ReverseDiffAdjoint ()))
113113 tmp_mean = mean (tmp_sol, dims = 3 )[:, :]
114114 tmp_var = var (tmp_sol, dims = 3 )[:, :]
115- sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var), tmp_mean
115+ sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var)
116116 end
117117
118118 losses = []
119- function callback (θ, l, pred )
119+ function callback (θ, state )
120120 begin
121- push! (losses, l )
121+ push! (losses, state . u )
122122 if length (losses) % 50 == 0
123123 println (" Current loss after $(length (losses)) iterations: $(losses[end ]) " )
124124 end
@@ -189,14 +189,14 @@ end
189189 sensealg = BacksolveAdjoint (autojacvec = ReverseDiffVJP ()),
190190 saveat = ts, trajectories = 10 , abstol = 1e-1 , reltol = 1e-1 )
191191 A = convert (Array, _sol)
192- sum (abs2, A .- 1 ), mean (A)
192+ sum (abs2, A .- 1 )
193193 end
194194
195195 # Actually training/fitting the model
196196 losses = []
197- function callback (θ, l, pred )
197+ function callback (θ, state )
198198 begin
199- push! (losses, l )
199+ push! (losses, state . u )
200200 if length (losses) % 1 == 0
201201 println (" Current loss after $(length (losses)) iterations: $(losses[end ]) " )
202202 end
0 commit comments