Skip to content

Commit cd67a51

Browse files
committed
fix: update GeneralDomain to the new ManifoldProjection API
1 parent 7b52768 commit cd67a51

File tree

5 files changed

+112
-82
lines changed

5 files changed

+112
-82
lines changed

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
67
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
78

89
[compat]
10+
ADTypes = "1.9.0"
911
DiffEqCallbacks = "3"
1012
Documenter = "1"
1113
OrdinaryDiffEq = "6.88"

docs/src/projection.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ManifoldProjection
1212
Here we solve the harmonic oscillator:
1313

1414
```@example manifold
15-
using OrdinaryDiffEq, DiffEqCallbacks, Plots
15+
using OrdinaryDiffEq, DiffEqCallbacks, Plots, ADTypes
1616
1717
u0 = ones(2)
1818
function f(du, u, p, t)
@@ -28,14 +28,13 @@ to conserve the sum of squares:
2828
```@example manifold
2929
function g(resid, u, p, t)
3030
resid[1] = u[2]^2 + u[1]^2 - 2
31-
resid[2] = 0
3231
end
3332
```
3433

3534
To build the callback, we just call
3635

3736
```@example manifold
38-
cb = ManifoldProjection(g)
37+
cb = ManifoldProjection(g; autodiff = AutoForwardDiff(), resid_prototype = zeros(1))
3938
```
4039

4140
Using this callback, the Runge-Kutta method `Vern7` conserves energy. Note that the

src/domain.jl

Lines changed: 78 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,35 @@
22

33
abstract type AbstractDomainAffect{T, S, uType} end
44

5+
(f::AbstractDomainAffect)(integrator) = affect!(integrator, f)
6+
57
struct PositiveDomainAffect{T, S, uType} <: AbstractDomainAffect{T, S, uType}
68
abstol::T
79
scalefactor::S
810
u::uType
911
end
1012

11-
struct GeneralDomainAffect{autonomous, F, T, S, uType} <: AbstractDomainAffect{T, S, uType}
13+
struct GeneralDomainAffect{F <: AbstractNonAutonomousFunction, T, S, uType, A} <:
14+
AbstractDomainAffect{T, S, uType}
1215
g::F
1316
abstol::T
1417
scalefactor::S
1518
u::uType
1619
resid::uType
20+
autonomous::A
21+
end
1722

18-
function GeneralDomainAffect{autonomous}(g::F, abstol::T, scalefactor::S, u::uType,
19-
resid::uType) where {autonomous, F, T, S, uType
20-
}
21-
new{autonomous, F, T, S, uType}(g, abstol, scalefactor, u, resid)
23+
function initialize_general_domain_affect(cb, u, t, integrator)
24+
return initialize_general_domain_affect(cb.affect!, u, t, integrator)
25+
end
26+
function initialize_general_domain_affect(affect!::GeneralDomainAffect, u, t, integrator)
27+
if affect!.autonomous === nothing
28+
autonomous = maximum(SciMLBase.numargs(affect!.g.f)) ==
29+
2 + SciMLBase.isinplace(integrator.f)
30+
affect!.g.autonomous = autonomous
2231
end
2332
end
2433

25-
# definitions of callback functions
26-
27-
# Workaround since it is not possible to add methods to an abstract type:
28-
# https://github.com/JuliaLang/julia/issues/14919
29-
(f::PositiveDomainAffect)(integrator) = affect!(integrator, f)
30-
(f::GeneralDomainAffect)(integrator) = affect!(integrator, f)
31-
3234
# general method definitions for domain callbacks
3335

3436
"""
@@ -41,6 +43,8 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
4143
throw(ArgumentError("domain callback can only be applied to adaptive algorithms"))
4244
end
4345

46+
iip = Val(SciMLBase.isinplace(integrator.f))
47+
4448
# define array of next time step, absolute tolerance, and scale factor
4549
if uType <: Nothing
4650
if integrator.u isa Union{Number, StaticArraysCore.SArray}
@@ -55,7 +59,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
5559
scalefactor = S <: Nothing ? 1 // 2 : f.scalefactor
5660

5761
# setup callback and save additional arguments for checking next time step
58-
args = setup(f, integrator)
62+
args = setup(f, integrator, iip)
5963

6064
# obtain proposed next time step
6165
dt = get_proposed_dt(integrator)
@@ -80,7 +84,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
8084
end
8185

8286
# check whether time step is accepted
83-
isaccepted(u, p, t, abstol, f, args...) && break
87+
isaccepted(u, p, t, abstol, f, iip, args...) && break
8488

8589
# reduce time step
8690
dtcache = dt
@@ -120,20 +124,20 @@ was modified.
120124
modify_u!(integrator, ::AbstractDomainAffect) = false
121125

122126
"""
123-
setup(f::AbstractDomainAffect, integrator)
127+
setup(f::AbstractDomainAffect, integrator, ::Val{iip}) where {iip}
124128
125129
Setup callback `f` and return an arbitrary tuple whose elements are used as additional
126130
arguments in checking whether time step is accepted.
127131
"""
128-
setup(::AbstractDomainAffect, integrator) = ()
132+
setup(::AbstractDomainAffect, integrator, ::Val{iip}) where {iip} = ()
129133

130134
"""
131135
isaccepted(u, abstol, f::AbstractDomainAffect, args...)
132136
133137
Return whether `u` is an acceptable state vector at the next time point given absolute
134138
tolerance `abstol`, callback `f`, and other optional arguments.
135139
"""
136-
isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, args...) = true
140+
isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, ::Val{iip}, args...) where {iip} = true
137141

138142
# specific method definitions for positive domain callback
139143

@@ -175,27 +179,30 @@ function _set_neg_zero!(integrator, u::StaticArraysCore.SArray)
175179
end
176180

177181
# state vector is accepted if its entries are greater than -abstol
178-
isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect) = all(ui -> ui > -abstol, u)
179-
function isaccepted(u, p, t, abstol, ::PositiveDomainAffect)
182+
function isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect, ::Val{iip}) where {iip}
183+
return all(ui -> ui > -abstol, u)
184+
end
185+
function isaccepted(u, p, t, abstol, ::PositiveDomainAffect, ::Val{iip}) where {iip}
180186
length(u) == length(abstol) ||
181187
throw(DimensionMismatch("numbers of states and tolerances do not match"))
182-
all(ui > -tol for (ui, tol) in zip(u, abstol))
188+
return all(ui > -tol for (ui, tol) in zip(u, abstol))
183189
end
184190

185191
# specific method definitions for general domain callback
186192

187193
# create array of residuals
188-
function setup(f::GeneralDomainAffect, integrator)
189-
f.resid isa Nothing ? (similar(integrator.u),) : (f.resid,)
194+
setup(f::GeneralDomainAffect, integrator, ::Val{false}) = (nothing,)
195+
function setup(f::GeneralDomainAffect, integrator, ::Val{true})
196+
return f.resid === nothing ? (similar(integrator.u),) : (f.resid,)
190197
end
191198

192-
function isaccepted(u, p, t, abstol, f::GeneralDomainAffect{autonomous, F, T, S, uType},
193-
resid) where {autonomous, F, T, S, uType}
199+
function isaccepted(u, p, t, abstol, f::GeneralDomainAffect, ::Val{iip}, resid) where {iip}
194200
# calculate residuals
195-
if autonomous
201+
f.g.t = t
202+
if iip
196203
f.g(resid, u, p)
197204
else
198-
f.g(resid, u, p, t)
205+
resid = f.g(u, p)
199206
end
200207

201208
# accept time step if residuals are smaller than the tolerance
@@ -214,26 +221,32 @@ end
214221
"""
215222
GeneralDomain(
216223
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
217-
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
218-
abstol = 10 * eps()), kwargs...)
224+
autonomous = nothing, domain_jacobian = nothing,
225+
nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...)
219226
220227
A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
221-
a `PositiveDomain` callback to arbitrary domains. Domains are specified by
222-
in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` that calculate residuals of a
223-
state vector `u` at time `t` relative to that domain, with `p` the parameters of the
224-
corresponding integrator. As for `PositiveDomain`, steps are accepted if residuals
225-
of the extrapolated values at the next time step are below
226-
a certain tolerance. Moreover, this callback is automatically coupled with a
227-
`ManifoldProjection` that keeps all calculated state vectors close to the desired
228-
domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
229-
`ManifoldProjection` cannot guarantee that all state vectors of the solution are
230-
actually inside the domain. Thus, a `PositiveDomain` callback should generally be
231-
preferred.
228+
a `PositiveDomain` callback to arbitrary domains.
229+
230+
Domains are specified by
231+
- in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` if the corresponding
232+
ODEProblem is an inplace problem, or
233+
- out-of-place functions `g(u, p)` or `g(u, p, t)` if the corresponding ODEProblem is
234+
an out-of-place problem.
235+
236+
The function calculates residuals of a state vector `u` at time `t` relative to that domain,
237+
with `p` the parameters of the corresponding integrator.
238+
239+
As for `PositiveDomain`, steps are accepted if residuals of the extrapolated values at the
240+
next time step are below a certain tolerance. Moreover, this callback is automatically
241+
coupled with a `ManifoldProjection` that keeps all calculated state vectors close to the
242+
desired domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
243+
`ManifoldProjection` cannot guarantee that all state vectors of the solution are actually
244+
inside the domain. Thus, a `PositiveDomain` callback should generally be preferred.
232245
233246
## Arguments
234247
235-
- `g`: the implicit definition of the domain as a function `g(resid, u, p)` or
236-
`g(resid, u, p, t)` which is zero when the value is in the domain.
248+
- `g`: the implicit definition of the domain as a function as described above which is
249+
zero when the value is in the domain.
237250
- `u`: A prototype of the state vector of the integrator. A copy of it is saved and
238251
extrapolated values are written to it. If it is not specified,
239252
every application of the callback allocates a new copy of the state vector.
@@ -248,9 +261,13 @@ preferred.
248261
specified, time steps are halved.
249262
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
250263
If it is not specified, it is determined automatically.
251-
- `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
264+
- `kwargs`: All other keyword arguments are passed to [`ManifoldProjection`](@ref).
252265
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
253266
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.
267+
- `domain_jacobian`: The Jacobian of the domain (wrt the state). This has the same
268+
signature as `g` and the first argument is the Jacobian if inplace. This corresponds to
269+
the `manifold_jacobian` argument of [`ManifoldProjection`](@ref). Note that passing
270+
a `manifold_jacobian` is not supported for `GeneralDomain` and results in an error.
254271
255272
## References
256273
@@ -260,20 +277,27 @@ Non-negative solutions of ODEs. Applied Mathematics and Computation 170
260277
"""
261278
function GeneralDomain(
262279
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
263-
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
264-
abstol = 10 * eps()), kwargs...)
265-
_autonomous = SciMLBase._unwrap_val(autonomous)
266-
if u isa Nothing
267-
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing)
280+
autonomous = nothing, domain_jacobian = nothing, manifold_jacobian = missing,
281+
nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...)
282+
if manifold_jacobian !== missing
283+
throw(ArgumentError("`manifold_jacobian` is not supported for `GeneralDomain`. \
284+
Use `domain_jacobian` instead."))
285+
end
286+
manifold_projection = ManifoldProjection(
287+
g; save = false, autonomous, manifold_jacobian = domain_jacobian,
288+
kwargs..., nlsolve_kwargs...)
289+
domain = wrap_autonomous_function(autonomous, g)
290+
domain_jacobian = wrap_autonomous_function(autonomous, domain_jacobian)
291+
affect! = if u === nothing
292+
GeneralDomainAffect(domain, abstol, scalefactor, nothing, nothing, autonomous)
268293
else
269-
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u),
270-
deepcopy(u))
294+
GeneralDomainAffect(
295+
domain, abstol, scalefactor, deepcopy(u), deepcopy(u), autonomous)
271296
end
272-
condition = (u, t, integrator) -> true
273-
CallbackSet(
274-
ManifoldProjection(
275-
g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...),
276-
DiscreteCallback(condition, affect!; save_positions = (false, save)))
297+
domain_cb = DiscreteCallback(
298+
Returns(true), affect!; initialize = initialize_general_domain_affect,
299+
save_positions = (false, save))
300+
return CallbackSet(manifold_projection, domain_cb)
277301
end
278302

279303
@doc doc"""

src/manifold.jl

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ properties.
3131
would work in most cases (See [1] for details). Alternatively, a nonlinear solver as
3232
defined in the
3333
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
34-
can be specified.
34+
can be specified. Additionally if NonlinearSolve.jl is loaded and `nothing` is specified
35+
a polyalgorithm is used.
3536
- `save`: Whether to do the standard saving (applied after the callback)
3637
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
3738
`g(u, p)`. Specify it as `Val(::Bool)` to disable runtime branching. If `nothing`,
@@ -88,25 +89,8 @@ end
8889

8990
function ManifoldProjection(
9091
manifold, autodiff, manifold_jacobian, nlsolve, kwargs, autonomous)
91-
if autonomous isa Val{true} || autonomous isa Val{false}
92-
wrapped_manifold = TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(
93-
manifold, nothing)
94-
wrapped_manifold_jacobian = if manifold_jacobian === nothing
95-
nothing
96-
else
97-
TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(
98-
manifold_jacobian, nothing)
99-
end
100-
autonomous = SciMLBase._unwrap_val(autonomous)
101-
else
102-
_autonomous = autonomous === nothing ? false : autonomous
103-
wrapped_manifold = UntypedNonAutonomousFunction(_autonomous, manifold, nothing)
104-
wrapped_manifold_jacobian = if manifold_jacobian === nothing
105-
nothing
106-
else
107-
UntypedNonAutonomousFunction(_autonomous, manifold_jacobian, nothing)
108-
end
109-
end
92+
wrapped_manifold = wrap_autonomous_function(autonomous, manifold)
93+
wrapped_manifold_jacobian = wrap_autonomous_function(autonomous, manifold_jacobian)
11094
return ManifoldProjection(wrapped_manifold, wrapped_manifold_jacobian,
11195
autodiff, nothing, nlsolve, kwargs, autonomous)
11296
end
@@ -158,7 +142,20 @@ end
158142
export ManifoldProjection
159143

160144
# wrapper for non-autonomous functions
161-
@concrete mutable struct TypedNonAutonomousFunction{autonomous}
145+
function wrap_autonomous_function(autonomous::Union{Val{true}, Val{false}}, g)
146+
g === nothing && return nothing
147+
return TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(g, nothing)
148+
end
149+
function wrap_autonomous_function(autonomous::Union{Bool, Nothing}, g)
150+
g === nothing && return nothing
151+
autonomous = autonomous === nothing ? false : autonomous
152+
return UntypedNonAutonomousFunction(autonomous, g, nothing)
153+
end
154+
155+
abstract type AbstractNonAutonomousFunction end
156+
157+
@concrete mutable struct TypedNonAutonomousFunction{autonomous} <:
158+
AbstractNonAutonomousFunction
162159
f
163160
t::Any
164161
end
@@ -169,7 +166,7 @@ end
169166
(f::TypedNonAutonomousFunction{false})(u, p) = f.f(u, p, f.t)
170167
(f::TypedNonAutonomousFunction{true})(u, p) = f.f(u, p)
171168

172-
@concrete mutable struct UntypedNonAutonomousFunction
169+
@concrete mutable struct UntypedNonAutonomousFunction <: AbstractNonAutonomousFunction
173170
autonomous::Bool
174171
f
175172
t::Any

test/domain_tests.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DiffEqCallbacks, OrdinaryDiffEq, Test
1+
using DiffEqCallbacks, OrdinaryDiffEq, Test, ADTypes, NonlinearSolve
22

33
# Non-negative ODE examples
44
#
@@ -39,7 +39,11 @@ naive_sol_absval = solve(prob_absval, BS3())
3939
function g(resid, u, p)
4040
resid[1] = u[1] < 0 ? -u[1] : 0
4141
end
42-
general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0]),
42+
general_sol_absval = solve(
43+
prob_absval, BS3();
44+
callback = GeneralDomain(g, [1.0];
45+
autodiff = AutoForwardDiff(),
46+
nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())),
4347
save_everystep = false)
4448
@test all(x -> x[1] 0, general_sol_absval.u)
4549
@test general_sol_absval.errors[:l∞] < 9.9e-5
@@ -49,7 +53,11 @@ general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0]
4953
# test "non-autonomous" function
5054
g_t(resid, u, p, t) = g(resid, u, p)
5155

52-
general_t_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g_t, [1.0]),
56+
general_t_sol_absval = solve(
57+
prob_absval, BS3();
58+
callback = GeneralDomain(g_t, [1.0];
59+
autodiff = AutoForwardDiff(),
60+
nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())),
5361
save_everystep = false)
5462
@test general_sol_absval.t general_t_sol_absval.t
5563
@test general_sol_absval.u general_t_sol_absval.u

0 commit comments

Comments
 (0)