Skip to content

Commit 3928194

Browse files
refactor: use initialization_data in SciMLFunction constructors
1 parent af8cd67 commit 3928194

File tree

4 files changed

+21
-33
lines changed

4 files changed

+21
-33
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
359359
sparsity = false,
360360
analytic = nothing,
361361
split_idxs = nothing,
362-
initializeprob = nothing,
363-
update_initializeprob! = nothing,
364-
initializeprobmap = nothing,
365-
initializeprobpmap = nothing,
362+
initialization_data = nothing,
366363
kwargs...) where {iip, specialize}
367364
if !iscomplete(sys)
368365
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -463,10 +460,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
463460
observed = observedfun,
464461
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
465462
analytic = analytic,
466-
initializeprob = initializeprob,
467-
update_initializeprob! = update_initializeprob!,
468-
initializeprobmap = initializeprobmap,
469-
initializeprobpmap = initializeprobpmap)
463+
initialization_data)
470464
end
471465

472466
"""
@@ -496,10 +490,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
496490
sparse = false, simplify = false,
497491
eval_module = @__MODULE__,
498492
checkbounds = false,
499-
initializeprob = nothing,
500-
initializeprobmap = nothing,
501-
initializeprobpmap = nothing,
502-
update_initializeprob! = nothing,
493+
initialization_data = nothing,
503494
kwargs...) where {iip}
504495
if !iscomplete(sys)
505496
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -547,15 +538,12 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
547538
nothing
548539
end
549540

550-
DAEFunction{iip}(f,
541+
DAEFunction{iip}(f;
551542
sys = sys,
552543
jac = _jac === nothing ? nothing : _jac,
553544
jac_prototype = jac_prototype,
554545
observed = observedfun,
555-
initializeprob = initializeprob,
556-
initializeprobmap = initializeprobmap,
557-
initializeprobpmap = initializeprobpmap,
558-
update_initializeprob! = update_initializeprob!)
546+
initialization_data)
559547
end
560548

561549
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -567,6 +555,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
567555
eval_expression = false,
568556
eval_module = @__MODULE__,
569557
checkbounds = false,
558+
initialization_data = nothing,
570559
kwargs...) where {iip}
571560
if !iscomplete(sys)
572561
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
@@ -579,7 +568,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
579568
f(u, h, p, t) = f_oop(u, h, p, t)
580569
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
581570

582-
DDEFunction{iip}(f, sys = sys)
571+
DDEFunction{iip}(f; sys = sys, initialization_data)
583572
end
584573

585574
function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -591,6 +580,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
591580
eval_expression = false,
592581
eval_module = @__MODULE__,
593582
checkbounds = false,
583+
initialization_data = nothing,
594584
kwargs...) where {iip}
595585
if !iscomplete(sys)
596586
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
@@ -609,7 +599,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
609599
g(u, h, p, t) = g_oop(u, h, p, t)
610600
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
611601

612-
SDDEFunction{iip}(f, g, sys = sys)
602+
SDDEFunction{iip}(f, g; sys = sys, initialization_data)
613603
end
614604

615605
"""

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
544544
version = nothing, tgrad = false, sparse = false,
545545
jac = false, Wfact = false, eval_expression = false,
546546
eval_module = @__MODULE__,
547-
checkbounds = false,
547+
checkbounds = false, initialization_data = nothing,
548548
kwargs...) where {iip, specialize}
549549
if !iscomplete(sys)
550550
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
@@ -615,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
615615

616616
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
617617

618-
SDEFunction{iip, specialize}(f, g,
618+
SDEFunction{iip, specialize}(f, g;
619619
sys = sys,
620620
jac = _jac === nothing ? nothing : _jac,
621621
tgrad = _tgrad === nothing ? nothing : _tgrad,
622622
Wfact = _Wfact === nothing ? nothing : _Wfact,
623623
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
624-
mass_matrix = _M,
624+
mass_matrix = _M, initialization_data,
625625
observed = observedfun)
626626
end
627627

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p,
346346
u0map, pmap, defs, cmap, dvs, ps)
347347
kws = maybe_build_initialization_problem(
348348
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc)
349-
initprob = get(kws, :initializeprob, nothing)
350-
if initprob === nothing
351-
return nothing
352-
end
353-
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing),
354-
get(kws, :initializeprobmap, nothing),
355-
get(kws, :initializeprobpmap, nothing))
349+
return get(kws, :initialization_data, nothing)
356350
end
357351

358352
"""

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
344344
eval_expression = false,
345345
eval_module = @__MODULE__,
346346
sparse = false, simplify = false,
347+
initialization_data = nothing,
347348
kwargs...) where {iip}
348349
if !iscomplete(sys)
349350
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`")
@@ -376,14 +377,14 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
376377
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
377378
end
378379

379-
NonlinearFunction{iip}(f,
380+
NonlinearFunction{iip}(f;
380381
sys = sys,
381382
jac = _jac === nothing ? nothing : _jac,
382383
resid_prototype = resid_prototype,
383384
jac_prototype = sparse ?
384385
similar(calculate_jacobian(sys, sparse = sparse),
385386
Float64) : nothing,
386-
observed = observedfun)
387+
observed = observedfun, initialization_data)
387388
end
388389

389390
"""
@@ -395,7 +396,8 @@ respectively.
395396
"""
396397
function SciMLBase.IntervalNonlinearFunction(
397398
sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys), u0 = nothing;
398-
p = nothing, eval_expression = false, eval_module = @__MODULE__, kwargs...)
399+
p = nothing, eval_expression = false, eval_module = @__MODULE__,
400+
initialization_data = nothing, kwargs...)
399401
if !iscomplete(sys)
400402
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearFunction`")
401403
end
@@ -411,7 +413,8 @@ function SciMLBase.IntervalNonlinearFunction(
411413

412414
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
413415

414-
IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys)
416+
IntervalNonlinearFunction{false}(
417+
f; observed = observedfun, sys = sys, initialization_data)
415418
end
416419

417420
"""
@@ -884,6 +887,7 @@ function flatten(sys::NonlinearSystem, noeqs = false)
884887
observed = observed(sys),
885888
defaults = defaults(sys),
886889
guesses = guesses(sys),
890+
initialization_eqs = initialization_equations(sys),
887891
name = nameof(sys),
888892
description = description(sys),
889893
metadata = get_metadata(sys),

0 commit comments

Comments
 (0)