Skip to content

Commit b183af9

Browse files
feat: generalize CheckInit to DDEs
1 parent 089e31a commit b183af9

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

src/initialization.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,26 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
111111
return f(args...)
112112
end
113113

114+
"""
115+
Utility function to evaluate the RHS, adding extra arguments (such as history function for
116+
DDEs) wherever necessary.
117+
"""
118+
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
119+
return _evaluate_f(integrator, f, isinplace, u, p, t)
120+
end
121+
122+
function evaluate_f(integrator::DEIntegrator, prob::AbtsractDAEProblem, f, isinplace, u, p, t)
123+
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
124+
end
125+
126+
function evaluate_f(integrator::AbstractDDEIntegrator, prob, f, isinplace, u, p, t)
127+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
128+
end
129+
130+
function evaluate_f(integrator::AbstractSDDEIntegrator, prob, f, isinplace, u, p, t)
131+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
132+
end
133+
114134
"""
115135
$(TYPEDSIGNATURES)
116136
@@ -147,7 +167,7 @@ function get_initial_values(
147167
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
148168
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
149169
update_coefficients!(M, u0, p, t)
150-
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
170+
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
151171
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
152172

153173
normresid = isdefined(integrator.opts, :internalnorm) ?
@@ -165,7 +185,7 @@ function get_initial_values(
165185
p = parameter_values(integrator)
166186
t = current_time(integrator)
167187

168-
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
188+
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
169189
normresid = isdefined(integrator.opts, :internalnorm) ?
170190
integrator.opts.internalnorm(resid, t) : norm(resid)
171191

0 commit comments

Comments
 (0)