
Commit b25de84

khannay and Kevin Hannay authored
updated minibatch example to use Flux call (#490)
Co-authored-by: Kevin Hannay <[email protected]>
1 parent a812927 · commit b25de84


File changed: docs/src/examples/minibatch.md (108 additions, 66 deletions)
````diff
@@ -2,6 +2,8 @@
 
 ```julia
 using DifferentialEquations, Flux, Optim, DiffEqFlux, Plots
+using IterTools: ncycle
+
 
 function newtons_cooling(du, u, p, t)
   temp = u[1]
````

````diff
@@ -13,68 +15,94 @@ function true_sol(du, u, p, t)
   true_p = [log(2)/8.0, 100.0]
   newtons_cooling(du, u, true_p, t)
 end
-
+
+
+ann = FastChain(FastDense(1,8,tanh), FastDense(8,1,tanh))
+θ = initial_params(ann)
+
 function dudt_(u,p,t)
   ann(u, p).* u
 end
 
-function predict_adjoint(fullp, time_batch)
+function predict_adjoint(time_batch)
   Array(concrete_solve(prob, Tsit5(),
-    u0, fullp, saveat = time_batch))
+    u0, θ, saveat = time_batch))
 end
 
-function loss_adjoint(fullp, batch, time_batch)
-  pred = predict_adjoint(fullp,time_batch)
-  sum(abs2, batch - pred), pred
+function loss_adjoint(batch, time_batch)
+  pred = predict_adjoint(time_batch)
+  sum(abs2, batch - pred)#, pred
 end
 
-cb = function (p,l,pred;doplot=false) #callback function to observe training
-  display(l)
-  # plot current prediction against data
-  if doplot
-    pl = scatter(t,ode_data[1,:],label="data")
-    scatter!(pl,t,pred[1,:],label="prediction")
-    display(plot(pl))
-  end
-  return false
-end
 
 u0 = Float32[200.0]
 datasize = 30
-tspan = (0.0f0, 1.5f0)
+tspan = (0.0f0, 3.0f0)
 
 t = range(tspan[1], tspan[2], length=datasize)
 true_prob = ODEProblem(true_sol, u0, tspan)
 ode_data = Array(solve(true_prob, Tsit5(), saveat=t))
 
-ann = FastChain(FastDense(1,8,tanh), FastDense(8,1,tanh))
-pp = initial_params(ann)
-prob = ODEProblem{false}(dudt_, u0, tspan, pp)
-
+prob = ODEProblem{false}(dudt_, u0, tspan, θ)
 
 k = 10
 train_loader = Flux.Data.DataLoader(ode_data, t, batchsize = k)
 
+for (x, y) in train_loader
+  @show x
+  @show y
+end
+
 numEpochs = 300
+losses=[]
+cb() = begin
+  l=loss_adjoint(ode_data, t)
+  push!(losses, l)
+  @show l
+  pred=predict_adjoint(t)
+  pl = scatter(t,ode_data[1,:],label="data", color=:black, ylim=(150,200))
+  scatter!(pl,t,pred[1,:],label="prediction", color=:darkgreen)
+  display(plot(pl))
+  false
+end
+
+opt=ADAM(0.05)
+Flux.train!(loss_adjoint, Flux.params(θ), ncycle(train_loader,numEpochs), opt, cb=Flux.throttle(cb, 10))
+
+# Now let's see how well it generalizes to new initial conditions
+
+starting_temp=collect(10:30:250)
+true_prob_func(u0)=ODEProblem(true_sol, [u0], tspan)
+color_cycle=palette(:tab10)
+pl=plot()
+for (j,temp) in enumerate(starting_temp)
+  ode_test_sol = solve(ODEProblem(true_sol, [temp], (0.0f0,10.0f0)), Tsit5(), saveat=0.0:0.5:10.0)
+  ode_nn_sol = solve(ODEProblem{false}(dudt_, [temp], (0.0f0,10.0f0), θ))
+  scatter!(pl, ode_test_sol, var=(0,1), label="", color=color_cycle[j])
+  plot!(pl, ode_nn_sol, var=(0,1), label="", color=color_cycle[j], lw=2.0)
+end
+display(pl)
+title!("Neural ODE for Newton's Law of Cooling: Test Data")
+xlabel!("Time")
+ylabel!("Temp")
 
-using IterTools: ncycle
-res1 = DiffEqFlux.sciml_train(loss_adjoint, pp, ADAM(0.05), ncycle(train_loader, numEpochs), cb = cb, maxiters = numEpochs)
-cb(res1.minimizer,loss_adjoint(res1.minimizer, ode_data, t)...;doplot=true)
 
-```
+# How to use MLDataUtils
+using MLDataUtils
+train_loader, _, _ = kfolds((ode_data, t))
 
+@info "Now training using the MLDataUtils format"
+Flux.train!(loss_adjoint, Flux.params(θ), ncycle(eachbatch(train_loader[1], k), numEpochs), opt, cb=Flux.throttle(cb, 10))
+```
 
 When training a neural network we need to find the gradient with respect to our data set. There are three main ways to partition our data when using a training algorithm like gradient descent: stochastic, batching, and mini-batching. Stochastic gradient descent trains on a single random data point each epoch. This allows the neural network to better converge to the global minimum even on noisy data, but is computationally inefficient. Batch gradient descent trains on the whole data set each epoch and, while computationally efficient, is prone to converging to local minima. Mini-batching combines both of these advantages: by training on a small random "mini-batch" of the data each epoch it can converge to the global minimum while remaining more computationally efficient than stochastic descent. Typically we do this by randomly selecting subsets of the data each epoch and using these subsets to train on. We can also pre-batch the data by creating an iterator holding these randomly selected batches before beginning to train. The proper size for the batch can be determined experimentally. Let us see how to do this with Julia.
 
-
-
-
 For this example we will use a very simple ordinary differential equation, Newton's law of cooling. We can represent this in Julia like so.
 
-
-
 ```julia
 using DifferentialEquations, Flux, Optim, DiffEqFlux, Plots
+using IterTools: ncycle
+
 
 function newtons_cooling(du, u, p, t)
   temp = u[1]
````
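
As a rough sketch of the mini-batching idea described in the example's introduction above (illustrative only, not part of the commit; `model`, `xs`, `ys`, and `k` are placeholder names and a plain `Dense` layer stands in for the neural ODE), one epoch of hand-rolled mini-batch gradient descent can be written as:

```julia
# Hand-rolled mini-batch epoch on toy data (placeholders, not the example's model).
using Flux, Random

model = Dense(1, 1)                                  # stand-in model
xs, ys = rand(Float32, 1, 30), rand(Float32, 1, 30)  # stand-in data set
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM(0.05)
k = 10                                               # mini-batch size

idx = shuffle(1:size(xs, 2))                         # new random order each epoch
for start in 1:k:length(idx)
    b = idx[start:min(start + k - 1, end)]           # indices of this mini-batch
    gs = Flux.gradient(() -> loss(xs[:, b], ys[:, b]), Flux.params(model))
    Flux.Optimise.update!(opt, Flux.params(model), gs)
end
```

With `k = 1` this reduces to stochastic gradient descent and with `k = size(xs, 2)` to full-batch descent; `Flux.Data.DataLoader(ode_data, t, batchsize = k)` in the example produces the same kind of mini-batches without the manual index bookkeeping.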

````diff
@@ -86,54 +114,52 @@ function true_sol(du, u, p, t)
   true_p = [log(2)/8.0, 100.0]
   newtons_cooling(du, u, true_p, t)
 end
-
-u0 = Float32[200.0]
-datasize = 30
-tspan = (0.0f0, 1.5f0)
-
-t = range(tspan[1], tspan[2], length=datasize)
-true_prob = ODEProblem(true_sol, u0, tspan)
-ode_data = Array(solve(true_prob, Tsit5(), saveat=t))
-
 ```
 
 Now we define a neural network using a linear approximation with one hidden layer of 8 neurons.
 
 ```julia
 ann = FastChain(FastDense(1,8,tanh), FastDense(8,1,tanh))
-pp = initial_params(ann)
-prob = ODEProblem{false}(dudt_, u0, tspan, pp)
+θ = initial_params(ann)
 
 function dudt_(u,p,t)
   ann(u, p).* u
 end
 ```
 
-
 From here we build a loss function around it.
 
 ```julia
-function predict_adjoint(fullp, time_batch)
+function predict_adjoint(time_batch)
   Array(concrete_solve(prob, Tsit5(),
-    u0, fullp, saveat = time_batch))
+    u0, θ, saveat = time_batch))
 end
 
-function loss_adjoint(fullp, batch, time_batch)
-  pred = predict_adjoint(fullp,time_batch)
-  sum(abs2, batch - pred), pred
+function loss_adjoint(batch, time_batch)
+  pred = predict_adjoint(time_batch)
+  sum(abs2, batch - pred)#, pred
 end
 ```
 
 To add support for batches of size `k` we use `Flux.Data.DataLoader`. To use this we pass in the `ode_data` and `t` as the 'x' and 'y' data to batch, respectively. The parameter `batchsize` controls the size of our batches. We check our implementation by iterating over the batched data.
 
 ```julia
+u0 = Float32[200.0]
+datasize = 30
+tspan = (0.0f0, 3.0f0)
+
+t = range(tspan[1], tspan[2], length=datasize)
+true_prob = ODEProblem(true_sol, u0, tspan)
+ode_data = Array(solve(true_prob, Tsit5(), saveat=t))
+prob = ODEProblem{false}(dudt_, u0, tspan, θ)
+
 k = 10
 train_loader = Flux.Data.DataLoader(ode_data, t, batchsize = k)
 for (x, y) in train_loader
   @show x
   @show y
 end
-
+
 
 #x = Float32[200.0 199.55284 199.1077 198.66454 198.22334 197.78413 197.3469 196.9116 196.47826 196.04686]
 #y = Float32[0.0, 0.05172414, 0.10344828, 0.15517241, 0.20689656, 0.25862068, 0.31034482, 0.36206895, 0.41379312, 0.46551725]
````
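
One detail worth spelling out about the new signatures in the hunk above: `loss_adjoint(batch, time_batch)` no longer receives the parameters as an argument, so training differentiates it with respect to the implicit parameter set `Flux.params(θ)` that the closure captures. A minimal sketch of a single mini-batch step in that style (illustrative, not part of the commit; it reuses `loss_adjoint`, `train_loader`, and `θ` from the example, and creates its own optimiser since `opt` is defined later in the walkthrough):

```julia
# One gradient step in the implicit-parameter style used by Flux.train!.
using Flux

batch, time_batch = first(train_loader)      # one (data, times) pair from the loader
ps = Flux.params(θ)                          # θ is the parameter vector the closure captures
gs = Flux.gradient(() -> loss_adjoint(batch, time_batch), ps)
Flux.Optimise.update!(ADAM(0.05), ps, gs)    # update θ in place
```

This is the step that `Flux.train!` repeats for every batch in the training loop shown later in the diff.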

````diff
@@ -143,26 +169,43 @@ end
 #y = Float32[1.0344827, 1.0862069, 1.137931, 1.1896552, 1.2413793, 1.2931035, 1.3448275, 1.3965517, 1.4482758, 1.5]
 ```
 
+Now we train the neural network with a user-defined callback function to display the loss and the plots, for a maximum of 300 epochs.
 
+```julia
+numEpochs = 300
+losses=[]
+cb() = begin
+  l=loss_adjoint(ode_data, t)
+  push!(losses, l)
+  @show l
+  pred=predict_adjoint(t)
+  pl = scatter(t,ode_data[1,:],label="data", color=:black, ylim=(150,200))
+  scatter!(pl,t,pred[1,:],label="prediction", color=:darkgreen)
+  display(plot(pl))
+  false
+end
+
+opt=ADAM(0.05)
+Flux.train!(loss_adjoint, Flux.params(θ), ncycle(train_loader,numEpochs), opt, cb=Flux.throttle(cb, 10))
+```
 
+Finally we can see how well our trained network will generalize to new initial conditions.
 
-Now we train the neural network with a user defined call back function to display loss and the graphs with a maximum of 300 epochs.
 ```julia
-numEpochs = 300
-cb = function (p,l,pred;doplot=false) #callback function to observe training
-  display(l)
-  # plot current prediction against data
-  if doplot
-    pl = scatter(t,ode_data[1,:],label="data")
-    scatter!(pl,t,pred[1,:],label="prediction")
-    display(plot(pl))
-  end
-  return false
+starting_temp=collect(10:30:250)
+true_prob_func(u0)=ODEProblem(true_sol, [u0], tspan)
+color_cycle=palette(:tab10)
+pl=plot()
+for (j,temp) in enumerate(starting_temp)
+  ode_test_sol = solve(ODEProblem(true_sol, [temp], (0.0f0,10.0f0)), Tsit5(), saveat=0.0:0.5:10.0)
+  ode_nn_sol = solve(ODEProblem{false}(dudt_, [temp], (0.0f0,10.0f0), θ))
+  scatter!(pl, ode_test_sol, var=(0,1), label="", color=color_cycle[j])
+  plot!(pl, ode_nn_sol, var=(0,1), label="", color=color_cycle[j], lw=2.0)
 end
-
-using IterTools: ncycle
-res1 = DiffEqFlux.sciml_train(loss_adjoint, pp, ADAM(0.05), ncycle(train_loader, numEpochs), cb = cb, maxiters = numEpochs)
-cb(res1.minimizer,loss_adjoint(res1.minimizer, ode_data, t)...;doplot=true)
+display(pl)
+title!("Neural ODE for Newton's Law of Cooling: Test Data")
+xlabel!("Time")
+ylabel!("Temp")
 ```
 
 We can also minibatch using tools from `MLDataUtils`. To do this we need to slightly change our implementation, as shown below, again with a batch size of `k` and the same number of epochs.
````
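
Two helpers in the new training call may be worth a quick illustration: `ncycle` (from IterTools) simply repeats an iterator, so `ncycle(train_loader, numEpochs)` makes `Flux.train!` run `numEpochs` passes over the mini-batches, and `Flux.throttle` rate-limits the plotting callback. A small self-contained sketch (not part of the commit):

```julia
# ncycle repeats an iterator; throttle rate-limits a function.
using Flux
using IterTools: ncycle

collect(ncycle(1:3, 2))                                   # returns [1, 2, 3, 1, 2, 3]
cb_demo = Flux.throttle(() -> println("callback fired"), 10)
cb_demo()                                                 # runs immediately
cb_demo()                                                 # skipped: within the 10-second window
```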

````diff
@@ -171,7 +214,6 @@ We can also minibatch using tools from MLDataUtils
 using MLDataUtils
 train_loader, _, _ = kfolds((ode_data, t))
 
-res1 = DiffEqFlux.sciml_train(loss_adjoint, pp, ADAM(0.05), ncycle(eachbatch(train_loader[1], k), numEpochs), cb = cb, maxiters = numEpochs)
-cb(res1.minimizer,loss_adjoint(res1.minimizer, ode_data, t)...;doplot=true)
-
+@info "Now training using the MLDataUtils format"
+Flux.train!(loss_adjoint, Flux.params(θ), ncycle(eachbatch(train_loader[1], k), numEpochs), opt, cb=Flux.throttle(cb, 10))
 ```
````
