Skip to content

Commit 37b7a08

Browse files
Use other solvers than Tsit5 in multiple_shoot (#521)
* Fix parameter name in docstring * Fix error message should say `< 1` and `> number of data points` instead of `<=` and `>=`. * Relax types * Make code solver agnostic instead of relying on hardcoded `Tsit5()`. * Remove preset grp_size from docstring Preset grp_size might lead to DomainErrors if the dataset has less than 5 data points. * Update tests to solver agnostic interface, add DomainError test * Update docs to use solver agnostic multiple shooting * Update Project.toml * Update Project.toml Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent aa6007c commit 37b7a08

File tree

4 files changed

+28
-13
lines changed

4 files changed

+28
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2828
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2929

3030
[compat]
31-
Adapt = "2, 3.0"
31+
Adapt = "3 - 3.2"
3232
ConsoleProgressMonitor = "0.1"
3333
DataInterpolations = "3.3"
3434
DiffEqBase = "6.41"

docs/src/examples/multiple_shooting.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function loss_function_param(ode_data, pred):: Float32
8787
end
8888

8989
function loss_neuralode(p)
90-
return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param, grp_size_param, loss_multiplier_param)
90+
return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param, Tsit5(), grp_size_param, loss_multiplier_param)
9191
end
9292

9393
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,

src/multiple_shooting.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ The default group size is 5 implying the whole dataset would be divided in group
55
The default continuity term is 100 implying any losses arising from the non-continuity of 2 different groups will be scaled by 100.
66
77
```julia
8-
multiple_shoot(p,ode_data,tsteps,prob,loss_function,grp_size=5,continuity_strength=100)
8+
multiple_shoot(p,ode_data,tsteps,prob,loss_function,grp_size,continuity_strength=100)
99
```
1010
Arguments:
1111
- `p`: The parameters of the Neural Network to be trained.
@@ -14,40 +14,51 @@ Arguments:
1414
- `prob`: ODE problem that the Neural Network attempts to solve.
1515
- `loss_function`: Any arbitrary function to calculate loss.
1616
- `grp_size`: The group size achieved after splitting the ode_data into equal sizes.
17-
- `continuity_strength`: Multiplying factor to ensure continuity of predictions throughout different groups.
17+
- `continuity_term`: Multiplying factor to ensure continuity of predictions throughout different groups.
1818
1919
!!!note
20-
The parameter 'continuity_strength' should be a relatively big number to enforce a large penalty whenever the last point of any group doesn't coincide with the first point of next group.
20+
The parameter 'continuity_term' should be a relatively big number to enforce a large penalty whenever the last point of any group doesn't coincide with the first point of next group.
2121
"""
22-
function multiple_shoot(p :: Array, ode_data :: Array, tsteps, prob :: ODEProblem, loss_function ::Function, grp_size :: Integer = 5, continuity_term :: Integer = 100)
22+
function multiple_shoot(
23+
p::AbstractArray,
24+
ode_data::AbstractArray,
25+
tsteps,
26+
prob::ODEProblem,
27+
loss_function::Function,
28+
solver::DiffEqBase.AbstractODEAlgorithm,
29+
grp_size::Integer,
30+
continuity_term::Real=100
31+
)
2332
datasize = length(ode_data[1,:])
2433

25-
@assert (grp_size >= 1 && grp_size <= datasize) "grp_size can't be <= 1 or >= number of data points"
34+
if grp_size < 1 || grp_size > datasize
35+
throw(DomainError(grp_size, "grp_size can't be < 1 or > number of data points"))
36+
end
2637

2738
tot_loss = 0
2839

2940
if(grp_size == datasize)
30-
grp_predictions = [solve(remake(prob, p = p, tspan = (tsteps[1],tsteps[datasize]), u0 = ode_data[:,1]), Tsit5(),saveat = tsteps)]
41+
grp_predictions = [solve(remake(prob, p=p, tspan=(tsteps[1],tsteps[datasize]), u0=ode_data[:,1]), solver, saveat=tsteps)]
3142
tot_loss = loss_function(ode_data, grp_predictions[1][:,1:grp_size])
3243
return tot_loss, grp_predictions[1], grp_predictions
3344
end
3445

3546
if(grp_size == 1)
3647
# store individual single shooting predictions for each group
37-
grp_predictions = [solve(remake(prob, p = p, tspan = (tsteps[i],tsteps[i+1]), u0 = ode_data[:,i]), Tsit5(),saveat = tsteps[i:i+1]) for i in 1:datasize-1]
48+
grp_predictions = [solve(remake(prob, p=p, tspan=(tsteps[i],tsteps[i+1]), u0=ode_data[:,i]), solver, saveat=tsteps[i:i+1]) for i in 1:datasize-1]
3849

3950
# calculate multiple shooting loss from the single shoots done in above step
4051
for i in 1:datasize-1
4152
tot_loss += loss_function(ode_data[:,i:i+1], grp_predictions[i][:, 1:grp_size]) + (continuity_term * sum(abs,grp_predictions[i][:,2] - ode_data[:,i+1]))
4253
end
4354

4455
# single shooting predictions from ode_data[:,1] (= u0)
45-
pred = solve(remake(prob, p = p, tspan = (tsteps[1],tsteps[datasize]), u0 = ode_data[:,1]), Tsit5(),saveat = tsteps)
56+
pred = solve(remake(prob, p=p, tspan=(tsteps[1],tsteps[datasize]), u0=ode_data[:,1]), solver, saveat=tsteps)
4657
return tot_loss, pred, grp_predictions
4758
end
4859

4960
# multiple shooting predictions
50-
grp_predictions = [solve(remake(prob, p = p, tspan = (tsteps[i],tsteps[i+grp_size-1]), u0 = ode_data[:,i]), Tsit5(),saveat = tsteps[i:i+grp_size-1]) for i in 1:grp_size-1:datasize-grp_size]
61+
grp_predictions = [solve(remake(prob, p=p, tspan=(tsteps[i],tsteps[i+grp_size-1]), u0=ode_data[:,i]), solver, saveat=tsteps[i:i+grp_size-1]) for i in 1:grp_size-1:datasize-grp_size]
5162

5263
# calculate multiple shooting loss
5364
for i in 1:grp_size-1:datasize-grp_size
@@ -56,6 +67,6 @@ function multiple_shoot(p :: Array, ode_data :: Array, tsteps, prob :: ODEProble
5667
end
5768

5869
# single shooting prediction
59-
pred = solve(remake(prob, p = p, tspan = (tsteps[1],tsteps[datasize]), u0 = ode_data[:,1]), Tsit5(),saveat = tsteps)
70+
pred = solve(remake(prob, p=p, tspan=(tsteps[1],tsteps[datasize]), u0=ode_data[:,1]), solver, saveat=tsteps)
6071
return tot_loss, pred, grp_predictions
6172
end

test/multiple_shoot.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function loss_function_param(ode_data, pred):: Float32
5151
end
5252

5353
function loss_neuralode_param(p)
54-
return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param, grp_size_param, loss_multiplier_param)
54+
return multiple_shoot(p, ode_data, tsteps, prob_param, loss_function_param, Tsit5(), grp_size_param, loss_multiplier_param)
5555
end
5656

5757

@@ -76,3 +76,7 @@ multiple_shoot_loss_2 = general_loss_function(multiple_shoot_result_neuralode_2)
7676
println("multiple_shoot_loss_2: ",multiple_shoot_loss_2)
7777

7878
@test multiple_shoot_loss_2 < single_shoot_loss
79+
80+
# test for DomainErrors
81+
@test_throws DomainError multiple_shoot(prob_neuralode.p, ode_data, tsteps, prob_param, loss_function_param, Tsit5(), 0, loss_multiplier_param)
82+
@test_throws DomainError multiple_shoot(prob_neuralode.p, ode_data, tsteps, prob_param, loss_function_param, Tsit5(), datasize + 1, loss_multiplier_param)

0 commit comments

Comments
 (0)