@@ -99,7 +99,7 @@ Random.seed!(238248735)
9999 sensealg = ReverseDiffAdjoint ()))
100100 tmp_mean = mean (tmp_sol, dims = 3 )[:, :]
101101 tmp_var = var (tmp_sol, dims = 3 )[:, :]
102- sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var), tmp_mean
102+ sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var)
103103 end
104104
105105 function loss_op (θ)
@@ -112,11 +112,11 @@ Random.seed!(238248735)
112112 sensealg = ReverseDiffAdjoint ()))
113113 tmp_mean = mean (tmp_sol, dims = 3 )[:, :]
114114 tmp_var = var (tmp_sol, dims = 3 )[:, :]
115- sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var), tmp_mean
115+ sum (abs2, truemean - tmp_mean) + 0.1 * sum (abs2, truevar - tmp_var)
116116 end
117117
118118 losses = []
119- function callback (θ , l, pred )
119+ function callback (state , l)
120120 begin
121121 push! (losses, l)
122122 if length (losses) % 50 == 0
@@ -189,12 +189,12 @@ end
189189 sensealg = BacksolveAdjoint (autojacvec = ReverseDiffVJP ()),
190190 saveat = ts, trajectories = 10 , abstol = 1e-1 , reltol = 1e-1 )
191191 A = convert (Array, _sol)
192- sum (abs2, A .- 1 ), mean (A)
192+ sum (abs2, A .- 1 )
193193 end
194194
195195 # Actually training/fitting the model
196196 losses = []
197- function callback (θ , l, pred )
197+ function callback (state , l)
198198 begin
199199 push! (losses, l)
200200 if length (losses) % 1 == 0
0 commit comments