Skip to content

Commit 9ecdae6

Browse files
feat: generalize CheckInit to DDEs
1 parent 089e31a commit 9ecdae6

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/initialization.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
111111
return f(args...)
112112
end
113113

114+
"""
115+
Utility function to evaluate the RHS, adding extra arguments (such as history function for
116+
DDEs) wherever necessary.
117+
"""
118+
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
119+
return _evaluate_f(integrator, f, isinplace, u, p, t)
120+
end
121+
122+
function evaluate_f(
123+
integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t)
124+
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
125+
end
126+
127+
function evaluate_f(integrator::AbstractDDEIntegrator, prob, f, isinplace, u, p, t)
128+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
129+
end
130+
131+
function evaluate_f(integrator::AbstractSDDEIntegrator, prob, f, isinplace, u, p, t)
132+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
133+
end
134+
114135
"""
115136
$(TYPEDSIGNATURES)
116137
@@ -147,7 +168,7 @@ function get_initial_values(
147168
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
148169
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
149170
update_coefficients!(M, u0, p, t)
150-
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
171+
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
151172
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
152173

153174
normresid = isdefined(integrator.opts, :internalnorm) ?
@@ -165,7 +186,7 @@ function get_initial_values(
165186
p = parameter_values(integrator)
166187
t = current_time(integrator)
167188

168-
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
189+
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
169190
normresid = isdefined(integrator.opts, :internalnorm) ?
170191
integrator.opts.internalnorm(resid, t) : norm(resid)
171192

0 commit comments

Comments
 (0)