Skip to content

Commit a7bbc65

Browse files
Update sde_neural.jl
1 parent 129e44e commit a7bbc65

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

test/sde_neural.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Random.seed!(238248735)
9999
sensealg = ReverseDiffAdjoint()))
100100
tmp_mean = mean(tmp_sol, dims = 3)[:, :]
101101
tmp_var = var(tmp_sol, dims = 3)[:, :]
102-
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
102+
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var)
103103
end
104104

105105
function loss_op(θ)
@@ -112,13 +112,13 @@ Random.seed!(238248735)
112112
sensealg = ReverseDiffAdjoint()))
113113
tmp_mean = mean(tmp_sol, dims = 3)[:, :]
114114
tmp_var = var(tmp_sol, dims = 3)[:, :]
115-
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
115+
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var)
116116
end
117117

118118
losses = []
119-
function callback(θ, l, pred)
119+
function callback(θ, state)
120120
begin
121-
push!(losses, l)
121+
push!(losses, state.u)
122122
if length(losses) % 50 == 0
123123
println("Current loss after $(length(losses)) iterations: $(losses[end])")
124124
end
@@ -189,14 +189,14 @@ end
189189
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()),
190190
saveat = ts, trajectories = 10, abstol = 1e-1, reltol = 1e-1)
191191
A = convert(Array, _sol)
192-
sum(abs2, A .- 1), mean(A)
192+
sum(abs2, A .- 1)
193193
end
194194

195195
# Actually training/fitting the model
196196
losses = []
197-
function callback(θ, l, pred)
197+
function callback(θ, state)
198198
begin
199-
push!(losses, l)
199+
push!(losses, state.u)
200200
if length(losses) % 1 == 0
201201
println("Current loss after $(length(losses)) iterations: $(losses[end])")
202202
end

0 commit comments

Comments
 (0)