2
2
3
3
abstract type AbstractDomainAffect{T, S, uType} end
4
4
5
+ (f:: AbstractDomainAffect )(integrator) = affect! (integrator, f)
6
+
5
7
struct PositiveDomainAffect{T, S, uType} <: AbstractDomainAffect{T, S, uType}
6
8
abstol:: T
7
9
scalefactor:: S
8
10
u:: uType
9
11
end
10
12
11
- struct GeneralDomainAffect{autonomous, F, T, S, uType} <: AbstractDomainAffect{T, S, uType}
13
+ struct GeneralDomainAffect{F <: AbstractNonAutonomousFunction , T, S, uType, A} < :
14
+ AbstractDomainAffect{T, S, uType}
12
15
g:: F
13
16
abstol:: T
14
17
scalefactor:: S
15
18
u:: uType
16
19
resid:: uType
20
+ autonomous:: A
21
+ end
17
22
18
- function GeneralDomainAffect {autonomous} (g:: F , abstol:: T , scalefactor:: S , u:: uType ,
19
- resid:: uType ) where {autonomous, F, T, S, uType
20
- }
21
- new {autonomous, F, T, S, uType} (g, abstol, scalefactor, u, resid)
23
+ function initialize_general_domain_affect (cb, u, t, integrator)
24
+ return initialize_general_domain_affect (cb. affect!, u, t, integrator)
25
+ end
26
+ function initialize_general_domain_affect (affect!:: GeneralDomainAffect , u, t, integrator)
27
+ if affect!. autonomous === nothing
28
+ autonomous = maximum (SciMLBase. numargs (affect!. g. f)) ==
29
+ 2 + SciMLBase. isinplace (integrator. f)
30
+ affect!. g. autonomous = autonomous
22
31
end
23
32
end
24
33
25
- # definitions of callback functions
26
-
27
- # Workaround since it is not possible to add methods to an abstract type:
28
- # https://github.com/JuliaLang/julia/issues/14919
29
- (f:: PositiveDomainAffect )(integrator) = affect! (integrator, f)
30
- (f:: GeneralDomainAffect )(integrator) = affect! (integrator, f)
31
-
32
34
# general method definitions for domain callbacks
33
35
34
36
"""
@@ -41,6 +43,8 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
41
43
throw (ArgumentError (" domain callback can only be applied to adaptive algorithms" ))
42
44
end
43
45
46
+ iip = Val (SciMLBase. isinplace (integrator. f))
47
+
44
48
# define array of next time step, absolute tolerance, and scale factor
45
49
if uType <: Nothing
46
50
if integrator. u isa Union{Number, StaticArraysCore. SArray}
@@ -55,7 +59,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
55
59
scalefactor = S <: Nothing ? 1 // 2 : f. scalefactor
56
60
57
61
# setup callback and save additional arguments for checking next time step
58
- args = setup (f, integrator)
62
+ args = setup (f, integrator, iip )
59
63
60
64
# obtain proposed next time step
61
65
dt = get_proposed_dt (integrator)
@@ -80,7 +84,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
80
84
end
81
85
82
86
# check whether time step is accepted
83
- isaccepted (u, p, t, abstol, f, args... ) && break
87
+ isaccepted (u, p, t, abstol, f, iip, args... ) && break
84
88
85
89
# reduce time step
86
90
dtcache = dt
@@ -120,20 +124,20 @@ was modified.
120
124
modify_u! (integrator, :: AbstractDomainAffect ) = false
121
125
122
126
"""
123
- setup(f::AbstractDomainAffect, integrator)
127
+ setup(f::AbstractDomainAffect, integrator, ::Val{iip}) where {iip}
124
128
125
129
Setup callback `f` and return an arbitrary tuple whose elements are used as additional
126
130
arguments in checking whether time step is accepted.
127
131
"""
128
- setup (:: AbstractDomainAffect , integrator) = ()
132
+ setup (:: AbstractDomainAffect , integrator, :: Val{iip} ) where {iip} = ()
129
133
130
134
"""
131
135
isaccepted(u, abstol, f::AbstractDomainAffect, args...)
132
136
133
137
Return whether `u` is an acceptable state vector at the next time point given absolute
134
138
tolerance `abstol`, callback `f`, and other optional arguments.
135
139
"""
136
- isaccepted (u, p, t, tolerance, :: AbstractDomainAffect , args... ) = true
140
+ isaccepted (u, p, t, tolerance, :: AbstractDomainAffect , :: Val{iip} , args... ) where {iip} = true
137
141
138
142
# specific method definitions for positive domain callback
139
143
@@ -175,27 +179,30 @@ function _set_neg_zero!(integrator, u::StaticArraysCore.SArray)
175
179
end
176
180
177
181
# state vector is accepted if its entries are greater than -abstol
178
- isaccepted (u, p, t, abstol:: Number , :: PositiveDomainAffect ) = all (ui -> ui > - abstol, u)
179
- function isaccepted (u, p, t, abstol, :: PositiveDomainAffect )
182
+ function isaccepted (u, p, t, abstol:: Number , :: PositiveDomainAffect , :: Val{iip} ) where {iip}
183
+ return all (ui -> ui > - abstol, u)
184
+ end
185
+ function isaccepted (u, p, t, abstol, :: PositiveDomainAffect , :: Val{iip} ) where {iip}
180
186
length (u) == length (abstol) ||
181
187
throw (DimensionMismatch (" numbers of states and tolerances do not match" ))
182
- all (ui > - tol for (ui, tol) in zip (u, abstol))
188
+ return all (ui > - tol for (ui, tol) in zip (u, abstol))
183
189
end
184
190
185
191
# specific method definitions for general domain callback
186
192
187
193
# create array of residuals
188
- function setup (f:: GeneralDomainAffect , integrator)
189
- f. resid isa Nothing ? (similar (integrator. u),) : (f. resid,)
194
+ setup (f:: GeneralDomainAffect , integrator, :: Val{false} ) = (nothing ,)
195
+ function setup (f:: GeneralDomainAffect , integrator, :: Val{true} )
196
+ return f. resid === nothing ? (similar (integrator. u),) : (f. resid,)
190
197
end
191
198
192
- function isaccepted (u, p, t, abstol, f:: GeneralDomainAffect{autonomous, F, T, S, uType} ,
193
- resid) where {autonomous, F, T, S, uType}
199
+ function isaccepted (u, p, t, abstol, f:: GeneralDomainAffect , :: Val{iip} , resid) where {iip}
194
200
# calculate residuals
195
- if autonomous
201
+ f. g. t = t
202
+ if iip
196
203
f. g (resid, u, p)
197
204
else
198
- f. g (resid, u, p, t )
205
+ resid = f. g (u, p)
199
206
end
200
207
201
208
# accept time step if residuals are smaller than the tolerance
@@ -214,26 +221,32 @@ end
214
221
"""
215
222
GeneralDomain(
216
223
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
217
- autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
218
- abstol = 10 * eps()), kwargs...)
224
+ autonomous = nothing, domain_jacobian = nothing,
225
+ nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...)
219
226
220
227
A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
221
- a `PositiveDomain` callback to arbitrary domains. Domains are specified by
222
- in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` that calculate residuals of a
223
- state vector `u` at time `t` relative to that domain, with `p` the parameters of the
224
- corresponding integrator. As for `PositiveDomain`, steps are accepted if residuals
225
- of the extrapolated values at the next time step are below
226
- a certain tolerance. Moreover, this callback is automatically coupled with a
227
- `ManifoldProjection` that keeps all calculated state vectors close to the desired
228
- domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
229
- `ManifoldProjection` cannot guarantee that all state vectors of the solution are
230
- actually inside the domain. Thus, a `PositiveDomain` callback should generally be
231
- preferred.
228
+ a `PositiveDomain` callback to arbitrary domains.
229
+
230
+ Domains are specified by
231
+ - in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` if the corresponding
232
+ ODEProblem is an inplace problem, or
233
+ - out-of-place functions `g(u, p)` or `g(u, p, t)` if the corresponding ODEProblem is
234
+ an out-of-place problem.
235
+
236
+ The function calculates residuals of a state vector `u` at time `t` relative to that domain,
237
+ with `p` the parameters of the corresponding integrator.
238
+
239
+ As for `PositiveDomain`, steps are accepted if residuals of the extrapolated values at the
240
+ next time step are below a certain tolerance. Moreover, this callback is automatically
241
+ coupled with a `ManifoldProjection` that keeps all calculated state vectors close to the
242
+ desired domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
243
+ `ManifoldProjection` cannot guarantee that all state vectors of the solution are actually
244
+ inside the domain. Thus, a `PositiveDomain` callback should generally be preferred.
232
245
233
246
## Arguments
234
247
235
- - `g`: the implicit definition of the domain as a function `g(resid, u, p)` or
236
- `g(resid, u, p, t)` which is zero when the value is in the domain.
248
+ - `g`: the implicit definition of the domain as a function as described above which is
249
+ zero when the value is in the domain.
237
250
- `u`: A prototype of the state vector of the integrator. A copy of it is saved and
238
251
extrapolated values are written to it. If it is not specified,
239
252
every application of the callback allocates a new copy of the state vector.
@@ -248,9 +261,13 @@ preferred.
248
261
specified, time steps are halved.
249
262
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
250
263
If it is not specified, it is determined automatically.
251
- - `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
264
+ - `kwargs`: All other keyword arguments are passed to [ `ManifoldProjection`](@ref) .
252
265
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
253
266
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.
267
+ - `domain_jacobian`: The Jacobian of the domain (wrt the state). This has the same
268
+ signature as `g` and the first argument is the Jacobian if inplace. This corresponds to
269
+ the `manifold_jacobian` argument of [`ManifoldProjection`](@ref). Note that passing
270
+ a `manifold_jacobian` is not supported for `GeneralDomain` and results in an error.
254
271
255
272
## References
256
273
@@ -260,20 +277,27 @@ Non-negative solutions of ODEs. Applied Mathematics and Computation 170
260
277
"""
261
278
function GeneralDomain (
262
279
g, u = nothing ; save = true , abstol = nothing , scalefactor = nothing ,
263
- autonomous = maximum (SciMLBase. numargs (g)) == 3 , nlsolve_kwargs = (;
264
- abstol = 10 * eps ()), kwargs... )
265
- _autonomous = SciMLBase. _unwrap_val (autonomous)
266
- if u isa Nothing
267
- affect! = GeneralDomainAffect {_autonomous} (g, abstol, scalefactor, nothing , nothing )
280
+ autonomous = nothing , domain_jacobian = nothing , manifold_jacobian = missing ,
281
+ nlsolve_kwargs = (; abstol = 10 * eps ()), kwargs... )
282
+ if manifold_jacobian != = missing
283
+ throw (ArgumentError (" `manifold_jacobian` is not supported for `GeneralDomain`. \
284
+ Use `domain_jacobian` instead." ))
285
+ end
286
+ manifold_projection = ManifoldProjection (
287
+ g; save = false , autonomous, manifold_jacobian = domain_jacobian,
288
+ kwargs... , nlsolve_kwargs... )
289
+ domain = wrap_autonomous_function (autonomous, g)
290
+ domain_jacobian = wrap_autonomous_function (autonomous, domain_jacobian)
291
+ affect! = if u === nothing
292
+ GeneralDomainAffect (domain, abstol, scalefactor, nothing , nothing , autonomous)
268
293
else
269
- affect! = GeneralDomainAffect {_autonomous} (g, abstol, scalefactor, deepcopy (u),
270
- deepcopy (u))
294
+ GeneralDomainAffect (
295
+ domain, abstol, scalefactor, deepcopy (u), deepcopy (u), autonomous )
271
296
end
272
- condition = (u, t, integrator) -> true
273
- CallbackSet (
274
- ManifoldProjection (
275
- g; save = false , autonomous, isinplace = Val (true ), kwargs... , nlsolve_kwargs... ),
276
- DiscreteCallback (condition, affect!; save_positions = (false , save)))
297
+ domain_cb = DiscreteCallback (
298
+ Returns (true ), affect!; initialize = initialize_general_domain_affect,
299
+ save_positions = (false , save))
300
+ return CallbackSet (manifold_projection, domain_cb)
277
301
end
278
302
279
303
@doc doc"""
0 commit comments