Skip to content

Commit b714f2d

Browse files
Merge pull request #783 from SciML/checkinit_callbacks
Create CheckInit and add tagging of initializations to callbacks
2 parents 0841a5e + 0e3f267 commit b714f2d

File tree

3 files changed

+90
-33
lines changed

3 files changed

+90
-33
lines changed

src/SciMLBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,11 @@ $(TYPEDEF)
343343
"""
344344
struct NoInit <: DAEInitializationAlgorithm end
345345

346+
"""
347+
$(TYPEDEF)
348+
"""
349+
struct CheckInit <: DAEInitializationAlgorithm end
350+
346351
# PDE Discretizations
347352

348353
"""

src/callbacks.jl

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ ContinuousCallback(condition, affect!, affect_neg!;
2222
rootfind = LeftRootFind,
2323
save_positions = (true, true),
2424
interp_points = 10,
25-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
25+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
26+
initializealg = nothing)
2627
```
2728
2829
```julia
@@ -34,7 +35,8 @@ ContinuousCallback(condition, affect!;
3435
save_positions = (true, true),
3536
affect_neg! = affect!,
3637
interp_points = 10,
37-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
38+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
39+
initializealg = nothing)
3840
```
3941
4042
Contains a single callback whose `condition` is a continuous function. The callback is triggered when this function evaluates to 0.
@@ -91,8 +93,26 @@ Contains a single callback whose `condition` is a continuous function. The callb
9193
- `repeat_nudge = 1//100`: This is used to set the next testing point after a
9294
previously found zero. Defaults to 1//100, which means after a callback, the next
9395
sign check will take place at t + dt*1//100 instead of at t to avoid repeats.
96+
- `initializealg = nothing`: In the context of a DAE, this is the algorithm that is used
97+
to run initialization after the effect. The default of `nothing` defers to the initialization
98+
algorithm provided in the `solve`.
99+
100+
!!! warn
101+
102+
The effect of using a callback with a DAE needs to be done with care because the solution
103+
`u` needs to satisfy the algebraic constraints before taking the next step. For this reason,
104+
a consistent initialization calculation must be run after running the callback. If the
105+
chosen initialization alg is `BrownBasicInit()` (the default for `solve`), then the initialization
106+
will change the algebraic variables to satisfy the conditions. Thus if `x` is an algebraic
107+
variable and the callback performs `x+=1`, the initialization may "revert" the change to
108+
satisfy the constraints. This behavior can be removed by setting `initializealg = CheckInit()`,
109+
which simply checks that the state `u` is consistent, but requires that the result of the
110+
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
111+
used as that will lead to an unstable step following initialization. This warning can be
112+
ignored for non-DAE ODEs.
94113
"""
95-
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <: AbstractContinuousCallback
114+
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
115+
AbstractContinuousCallback
96116
condition::F1
97117
affect!::F2
98118
affect_neg!::F3
@@ -106,19 +126,21 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <: AbstractContin
106126
abstol::T
107127
reltol::T2
108128
repeat_nudge::T3
129+
initializealg::T4
109130
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
110131
initialize::F4, finalize::F5, idxs::I, rootfind,
111132
interp_points, save_positions, dtrelax::R, abstol::T,
112133
reltol::T2,
113-
repeat_nudge::T3) where {F1, F2, F3, F4, F5, T, T2, T3, I, R
134+
repeat_nudge::T3,
135+
initializealg::T4) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
114136
}
115137
_condition = prepare_function(condition)
116-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, I, R}(_condition,
138+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
117139
affect!, affect_neg!,
118140
initialize, finalize, idxs, rootfind,
119141
interp_points,
120142
BitArray(collect(save_positions)),
121-
dtrelax, abstol, reltol, repeat_nudge)
143+
dtrelax, abstol, reltol, repeat_nudge, initializealg)
122144
end
123145
end
124146

@@ -131,12 +153,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
131153
interp_points = 10,
132154
dtrelax = 1,
133155
abstol = 10eps(), reltol = 0,
134-
repeat_nudge = 1 // 100)
156+
repeat_nudge = 1 // 100,
157+
initializealg = nothing)
135158
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
136159
idxs,
137160
rootfind, interp_points,
138161
save_positions,
139-
dtrelax, abstol, reltol, repeat_nudge)
162+
dtrelax, abstol, reltol, repeat_nudge, initializealg)
140163
end
141164

142165
function ContinuousCallback(condition, affect!;
@@ -148,11 +171,12 @@ function ContinuousCallback(condition, affect!;
148171
affect_neg! = affect!,
149172
interp_points = 10,
150173
dtrelax = 1,
151-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
174+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
175+
initializealg = nothing)
152176
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
153177
rootfind, interp_points,
154178
collect(save_positions),
155-
dtrelax, abstol, reltol, repeat_nudge)
179+
dtrelax, abstol, reltol, repeat_nudge, initializealg)
156180
end
157181

158182
"""
@@ -164,7 +188,8 @@ VectorContinuousCallback(condition, affect!, affect_neg!, len;
164188
rootfind = LeftRootFind,
165189
save_positions = (true, true),
166190
interp_points = 10,
167-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
191+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
192+
initializealg = nothing)
168193
```
169194
170195
```julia
@@ -176,7 +201,8 @@ VectorContinuousCallback(condition, affect!, len;
176201
save_positions = (true, true),
177202
affect_neg! = affect!,
178203
interp_points = 10,
179-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
204+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
205+
initializealg = nothing)
180206
```
181207
182208
This is also a subtype of `AbstractContinuousCallback`. `CallbackSet` is not feasible when you have many callbacks,
@@ -194,7 +220,7 @@ multiple events.
194220
195221
Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
196222
"""
197-
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <:
223+
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
198224
AbstractContinuousCallback
199225
condition::F1
200226
affect!::F2
@@ -210,20 +236,22 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <:
210236
abstol::T
211237
reltol::T2
212238
repeat_nudge::T3
239+
initializealg::T4
213240
function VectorContinuousCallback(
214241
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
215242
initialize::F4, finalize::F5, idxs::I, rootfind,
216243
interp_points, save_positions, dtrelax::R,
217244
abstol::T, reltol::T2,
218-
repeat_nudge::T3) where {F1, F2, F3, F4, F5, T, T2,
219-
T3, I, R}
245+
repeat_nudge::T3,
246+
initializealg::T4) where {F1, F2, F3, F4, F5, T, T2,
247+
T3, T4, I, R}
220248
_condition = prepare_function(condition)
221-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, I, R}(_condition,
249+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
222250
affect!, affect_neg!, len,
223251
initialize, finalize, idxs, rootfind,
224252
interp_points,
225253
BitArray(collect(save_positions)),
226-
dtrelax, abstol, reltol, repeat_nudge)
254+
dtrelax, abstol, reltol, repeat_nudge, initializealg)
227255
end
228256
end
229257

@@ -235,13 +263,14 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
235263
save_positions = (true, true),
236264
interp_points = 10,
237265
dtrelax = 1,
238-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
266+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
267+
initializealg = nothing)
239268
VectorContinuousCallback(condition, affect!, affect_neg!, len,
240269
initialize, finalize,
241270
idxs,
242271
rootfind, interp_points,
243272
save_positions, dtrelax,
244-
abstol, reltol, repeat_nudge)
273+
abstol, reltol, repeat_nudge, initializealg)
245274
end
246275

247276
function VectorContinuousCallback(condition, affect!, len;
@@ -253,20 +282,22 @@ function VectorContinuousCallback(condition, affect!, len;
253282
affect_neg! = affect!,
254283
interp_points = 10,
255284
dtrelax = 1,
256-
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
285+
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
286+
initializealg = nothing)
257287
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
258288
idxs,
259289
rootfind, interp_points,
260290
collect(save_positions),
261-
dtrelax, abstol, reltol, repeat_nudge)
291+
dtrelax, abstol, reltol, repeat_nudge, initializealg)
262292
end
263293

264294
"""
265295
```julia
266296
DiscreteCallback(condition, affect!;
267297
initialize = INITIALIZE_DEFAULT,
268298
finalize = FINALIZE_DEFAULT,
269-
save_positions = (true, true))
299+
save_positions = (true, true),
300+
initializealg = nothing)
270301
```
271302
272303
# Arguments
@@ -291,26 +322,48 @@ DiscreteCallback(condition, affect!;
291322
- `finalize`: This is a function `(c,u,t,integrator)` which can be used to finalize
292323
the state of the callback `c`. It should can the argument `c` and the return is
293324
ignored.
325+
- `initializealg = nothing`: In the context of a DAE, this is the algorithm that is used
326+
to run initialization after the effect. The default of `nothing` defers to the initialization
327+
algorithm provided in the `solve`.
328+
329+
!!! warn
330+
331+
The effect of using a callback with a DAE needs to be done with care because the solution
332+
`u` needs to satisfy the algebraic constraints before taking the next step. For this reason,
333+
a consistent initialization calculation must be run after running the callback. If the
334+
chosen initialization alg is `BrownBasicInit()` (the default for `solve`), then the initialization
335+
will change the algebraic variables to satisfy the conditions. Thus if `x` is an algebraic
336+
variable and the callback performs `x+=1`, the initialization may "revert" the change to
337+
satisfy the constraints. This behavior can be removed by setting `initializealg = CheckInit()`,
338+
which simply checks that the state `u` is consistent, but requires that the result of the
339+
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
340+
used as that will lead to an unstable step following initialization. This warning can be
341+
ignored for non-DAE ODEs.
294342
"""
295-
struct DiscreteCallback{F1, F2, F3, F4} <: AbstractDiscreteCallback
343+
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
296344
condition::F1
297345
affect!::F2
298346
initialize::F3
299347
finalize::F4
300348
save_positions::BitArray{1}
349+
initializealg::F5
301350
function DiscreteCallback(condition::F1, affect!::F2,
302351
initialize::F3, finalize::F4,
303-
save_positions) where {F1, F2, F3, F4}
352+
save_positions,
353+
initializealg::F5) where {F1, F2, F3, F4, F5}
304354
_condition = prepare_function(condition)
305-
new{typeof(_condition), F2, F3, F4}(_condition,
355+
new{typeof(_condition), F2, F3, F4, F5}(_condition,
306356
affect!, initialize, finalize,
307-
BitArray(collect(save_positions)))
357+
BitArray(collect(save_positions)),
358+
initializealg)
308359
end
309360
end
310361
function DiscreteCallback(condition, affect!;
311362
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
312-
save_positions = (true, true))
313-
DiscreteCallback(condition, affect!, initialize, finalize, save_positions)
363+
save_positions = (true, true),
364+
initializealg = nothing)
365+
DiscreteCallback(
366+
condition, affect!, initialize, finalize, save_positions, initializealg)
314367
end
315368

316369
"""

src/ensemble/ensemble_solutions.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,12 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
111111
res = norm(m_final - m_final_analytic)
112112
weak_errors[:weak_final] = res
113113
if weak_timeseries_errors
114-
115114
if analyticvoa
116115
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic.u[i] for j in 1:length(u)])
117-
for i in 1:length(u[1])]
116+
for i in 1:length(u[1])]
118117
else
119118
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
120-
for i in 1:length(u[1])]
119+
for i in 1:length(u[1])]
121120
end
122121
ts_l2_errors = [sqrt.(sum(abs2, err) / length(err)) for err in ts_weak_errors]
123122
l2_tmp = sqrt(sum(abs2, ts_l2_errors) / length(ts_l2_errors))
@@ -128,8 +127,8 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
128127
if weak_dense_errors
129128
densetimes = collect(range(u[1].t[1], stop = u[1].t[end], length = 100))
130129
u_analytic = [[sol.prob.f.analytic(sol.prob.u0, sol.prob.p, densetimes[i],
131-
sol.W(densetimes[i])[1])
132-
for i in eachindex(densetimes)] for sol in u]
130+
sol.W(densetimes[i])[1])
131+
for i in eachindex(densetimes)] for sol in u]
133132

134133
udense = [u[j](densetimes) for j in 1:length(u)]
135134
dense_weak_errors = [mean([udense[j].u[i] - u_analytic[j][i] for j in 1:length(u)])

0 commit comments

Comments
 (0)