@@ -109,18 +109,16 @@ mean and variance from `n` runs at each time point and uses the distance from
109
109
the data values:
110
110
111
111
``` julia
112
- function predict_neuralsde (p)
113
- return Array (neuralsde (u0 , p))
112
+ function predict_neuralsde (p, u = u0 )
113
+ return Array (neuralsde (u , p))
114
114
end
115
115
116
116
function loss_neuralsde (p; n = 100 )
117
- samples = [predict_neuralsde (p) for i in 1 : n]
118
- means = reshape (mean .([[samples[i][j] for i in 1 : length (samples)]
119
- for j in 1 : length (samples[1 ])]),
120
- size (samples[1 ])... )
121
- vars = reshape (var .([[samples[i][j] for i in 1 : length (samples)]
122
- for j in 1 : length (samples[1 ])]),
123
- size (samples[1 ])... )
117
+ u = repeat (reshape (u0, :, 1 ), 1 , n)
118
+ samples = predict_neuralsde (p, u)
119
+ means = mean (samples, dims = 2 )
120
+ vars = var (samples, dims = 2 , mean = means)[:, 1 , :]
121
+ means = means[:, 1 , :]
124
122
loss = sum (abs2, sde_data - means) + sum (abs2, sde_data_vars - vars)
125
123
return loss, means, vars
126
124
end
@@ -143,9 +141,9 @@ callback = function (p, loss, means, vars; doplot = false)
143
141
display (loss)
144
142
145
143
# plot current prediction against data
146
- plt = scatter (tsteps, sde_data[1 ,:], yerror = sde_data_vars[1 ,:],
147
- ylim = (- 4.0 , 8.0 ), label = " data" )
148
- scatter! (plt, tsteps, means[1 ,:], ribbon = vars[1 ,:], label = " prediction" )
144
+ plt = Plots . scatter (tsteps, sde_data[1 ,:], yerror = sde_data_vars[1 ,:],
145
+ ylim = (- 4.0 , 8.0 ), label = " data" )
146
+ Plots . scatter! (plt, tsteps, means[1 ,:], ribbon = vars[1 ,:], label = " prediction" )
149
147
push! (list_plots, plt)
150
148
151
149
if doplot
@@ -180,17 +178,11 @@ result2 = DiffEqFlux.sciml_train((p) -> loss_neuralsde(p, n = 100),
180
178
And now we plot the solution to an ensemble of the trained neural SDE:
181
179
182
180
``` julia
183
- samples = [predict_neuralsde (result2. minimizer) for i in 1 : 1000 ]
184
- means = reshape (mean .([[samples[i][j] for i in 1 : length (samples)]
185
- for j in 1 : length (samples[1 ])]),
186
- size (samples[1 ])... )
187
- vars = reshape (var .([[samples[i][j] for i in 1 : length (samples)]
188
- for j in 1 : length (samples[1 ])]),
189
- size (samples[1 ])... )
190
-
191
- plt2 = scatter (tsteps, sde_data' , yerror = sde_data_vars' ,
192
- label = " data" , title = " Neural SDE: After Training" ,
193
- xlabel = " Time" )
181
+ _, means, vars = loss_neuralsde (result2. minimizer, n = 1000 )
182
+
183
+ plt2 = Plots. scatter (tsteps, sde_data' , yerror = sde_data_vars' ,
184
+ label = " data" , title = " Neural SDE: After Training" ,
185
+ xlabel = " Time" )
194
186
plot! (plt2, tsteps, means' , lw = 8 , ribbon = vars' , label = " prediction" )
195
187
196
188
plt = plot (plt1, plt2, layout = (2 , 1 ))
0 commit comments