Skip to content

Commit ee4937f

Browse files
author
Avik Pal
authored
Compute the different trajectories in parallel (#519)
1 parent 8be3f5a commit ee4937f

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

docs/src/examples/neural_sde.md

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -109,18 +109,16 @@ mean and variance from `n` runs at each time point and uses the distance from
109109
the data values:
110110

111111
```julia
112-
function predict_neuralsde(p)
113-
return Array(neuralsde(u0, p))
112+
function predict_neuralsde(p, u = u0)
113+
return Array(neuralsde(u, p))
114114
end
115115

116116
function loss_neuralsde(p; n = 100)
117-
samples = [predict_neuralsde(p) for i in 1:n]
118-
means = reshape(mean.([[samples[i][j] for i in 1:length(samples)]
119-
for j in 1:length(samples[1])]),
120-
size(samples[1])...)
121-
vars = reshape(var.([[samples[i][j] for i in 1:length(samples)]
122-
for j in 1:length(samples[1])]),
123-
size(samples[1])...)
117+
u = repeat(reshape(u0, :, 1), 1, n)
118+
samples = predict_neuralsde(p, u)
119+
means = mean(samples, dims = 2)
120+
vars = var(samples, dims = 2, mean = means)[:, 1, :]
121+
means = means[:, 1, :]
124122
loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
125123
return loss, means, vars
126124
end
@@ -143,9 +141,9 @@ callback = function (p, loss, means, vars; doplot = false)
143141
display(loss)
144142

145143
# plot current prediction against data
146-
plt = scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:],
147-
ylim = (-4.0, 8.0), label = "data")
148-
scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction")
144+
plt = Plots.scatter(tsteps, sde_data[1,:], yerror = sde_data_vars[1,:],
145+
ylim = (-4.0, 8.0), label = "data")
146+
Plots.scatter!(plt, tsteps, means[1,:], ribbon = vars[1,:], label = "prediction")
149147
push!(list_plots, plt)
150148

151149
if doplot
@@ -180,17 +178,11 @@ result2 = DiffEqFlux.sciml_train((p) -> loss_neuralsde(p, n = 100),
180178
And now we plot the solution to an ensemble of the trained neural SDE:
181179

182180
```julia
183-
samples = [predict_neuralsde(result2.minimizer) for i in 1:1000]
184-
means = reshape(mean.([[samples[i][j] for i in 1:length(samples)]
185-
for j in 1:length(samples[1])]),
186-
size(samples[1])...)
187-
vars = reshape(var.([[samples[i][j] for i in 1:length(samples)]
188-
for j in 1:length(samples[1])]),
189-
size(samples[1])...)
190-
191-
plt2 = scatter(tsteps, sde_data', yerror = sde_data_vars',
192-
label = "data", title = "Neural SDE: After Training",
193-
xlabel = "Time")
181+
_, means, vars = loss_neuralsde(result2.minimizer, n = 1000)
182+
183+
plt2 = Plots.scatter(tsteps, sde_data', yerror = sde_data_vars',
184+
label = "data", title = "Neural SDE: After Training",
185+
xlabel = "Time")
194186
plot!(plt2, tsteps, means', lw = 8, ribbon = vars', label = "prediction")
195187

196188
plt = plot(plt1, plt2, layout = (2, 1))

0 commit comments

Comments
 (0)