Skip to content

Commit f970ee7

Browse files
Merge pull request #885 from AayushSabharwal/as/init-infra
feat: add infrastructure for initialization of different problem types
2 parents b9ac7b5 + dfcb209 commit f970ee7

16 files changed

+538
-109
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ StableRNGs = "1.0"
8888
StaticArrays = "1.7"
8989
StaticArraysCore = "1.4"
9090
Statistics = "1.10"
91-
SymbolicIndexingInterface = "0.3.34"
91+
SymbolicIndexingInterface = "0.3.36"
9292
Tables = "1.11"
9393
Zygote = "0.6.67"
9494
julia = "1.10"

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import FunctionWrappersWrappers
2222
import RuntimeGeneratedFunctions
2323
import EnumX
2424
import ADTypes: ADTypes, AbstractADType
25-
import Accessors: @set, @reset
25+
import Accessors: @set, @reset, @delete
2626
using Expronicon.ADT: @match
2727

2828
using Reexport

src/initialization.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
111111
return f(args...)
112112
end
113113

114+
"""
115+
Utility function to evaluate the RHS, adding extra arguments (such as history function for
116+
DDEs) wherever necessary.
117+
"""
118+
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
119+
return _evaluate_f(integrator, f, isinplace, u, p, t)
120+
end
121+
122+
function evaluate_f(
123+
integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t)
124+
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
125+
end
126+
127+
function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
128+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
129+
end
130+
131+
function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
132+
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
133+
end
134+
114135
"""
115136
$(TYPEDSIGNATURES)
116137
@@ -147,7 +168,7 @@ function get_initial_values(
147168
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
148169
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
149170
update_coefficients!(M, u0, p, t)
150-
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
171+
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
151172
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
152173

153174
normresid = isdefined(integrator.opts, :internalnorm) ?
@@ -165,7 +186,7 @@ function get_initial_values(
165186
p = parameter_values(integrator)
166187
t = current_time(integrator)
167188

168-
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
189+
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
169190
normresid = isdefined(integrator.opts, :internalnorm) ?
170191
integrator.opts.internalnorm(resid, t) : norm(resid)
171192

src/integrator_interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ function isadaptive(integrator::DEIntegrator)
925925
isdefined(integrator.opts, :adaptive) ? integrator.opts.adaptive : false
926926
end
927927

928-
function SymbolicIndexingInterface.get_history_function(integ::AbstractDDEIntegrator)
928+
function SymbolicIndexingInterface.get_history_function(integ::Union{
929+
AbstractDDEIntegrator, AbstractSDDEIntegrator})
929930
DDESolutionHistoryWrapper(get_sol(integ))
930931
end

src/problems/dde_problems.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,19 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <:
253253
end
254254
end
255255

256+
function ConstructionBase.constructorof(::Type{P}) where {P <: DDEProblem}
257+
function ctor(f, u0, h, tspan, p, constant_lags, dependent_lags,
258+
kw, neutral, order_discontinuity_t0, problem_type)
259+
if f isa AbstractDDEFunction
260+
iip = isinplace(f)
261+
else
262+
iip = isinplace(f, 5)
263+
end
264+
return DDEProblem{iip}(f, u0, h, tspan, p; kw..., constant_lags, dependent_lags,
265+
neutral, order_discontinuity_t0, problem_type)
266+
end
267+
end
268+
256269
DDEProblem(f, args...; kwargs...) = DDEProblem(DDEFunction(f), args...; kwargs...)
257270

258271
function DDEProblem(f::AbstractDDEFunction, args...; kwargs...)

src/problems/nonlinear_problems.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ function NonlinearProblem(f::AbstractODEFunction, u0, p = NullParameters(); kwar
222222
NonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
223223
end
224224

225+
function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearProblem}
226+
function ctor(f, u0, p, pt, kw)
227+
if f isa AbstractNonlinearFunction
228+
iip = isinplace(f)
229+
else
230+
iip = isinplace(f, 4)
231+
end
232+
return NonlinearProblem{iip}(f, u0, p, pt; kw...)
233+
end
234+
end
235+
225236
"""
226237
$(SIGNATURES)
227238
@@ -322,6 +333,17 @@ function NonlinearLeastSquaresProblem(f, u0, p = NullParameters(); kwargs...)
322333
return NonlinearLeastSquaresProblem(NonlinearFunction(f), u0, p; kwargs...)
323334
end
324335

336+
function ConstructionBase.constructorof(::Type{P}) where {P <: NonlinearLeastSquaresProblem}
337+
function ctor(f, u0, p, kw)
338+
if f isa AbstractNonlinearFunction
339+
iip = isinplace(f)
340+
else
341+
iip = isinplace(f, 4)
342+
end
343+
return NonlinearProblem{iip}(f, u0, p; kw...)
344+
end
345+
end
346+
325347
@doc doc"""
326348
SCCNonlinearProblem(probs, explicitfuns!)
327349

src/problems/sdde_problems.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,19 @@ end
157157
function SDDEProblem(f::AbstractSDDEFunction, args...; kwargs...)
158158
SDDEProblem{isinplace(f)}(f, args...; kwargs...)
159159
end
160+
161+
function ConstructionBase.constructorof(::Type{P}) where {P <: SDDEProblem}
162+
function ctor(f, g, u0, h, tspan, p, noise, constant_lags, dependent_lags, kw,
163+
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
164+
if f isa AbstractSDDEFunction
165+
iip = isinplace(f)
166+
else
167+
iip = isinplace(f, 5)
168+
end
169+
return SDDEProblem{iip}(
170+
f, g, u0, h, tspan, p; kw..., noise, constant_lags, dependent_lags,
171+
noise_rate_prototype, seed, neutral, order_discontinuity_t0)
172+
end
173+
end
174+
175+
SymbolicIndexingInterface.get_history_function(prob::AbstractSDDEProblem) = prob.h

src/problems/sde_problems.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,17 @@ function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
125125
SDEProblem{iip}(SDEFunction{iip}(f, g), u0, tspan, p; kwargs...)
126126
end
127127

128+
function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
129+
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
130+
if f isa AbstractSDEFunction
131+
iip = isinplace(f)
132+
else
133+
iip = isinplace(f, 4)
134+
end
135+
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
136+
end
137+
end
138+
128139
"""
129140
$(TYPEDEF)
130141
"""

0 commit comments

Comments
 (0)