Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
kwargs1 = merge(kwargs1, (; tstops))
end

return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should at least be commented why it's done. It's a bit of an odd way to get there, but understand why

end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

Expand Down Expand Up @@ -963,8 +964,10 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
kwargs1 = merge(kwargs1, (; tstops))
end

DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs..., kwargs1...)
# Call `remake` so it runs initialization if it is trivial
return remake(DAEProblem{iip}(
f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs..., kwargs1...))
end

function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
Expand All @@ -991,7 +994,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
end
f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
symbolic_u0 = true,
symbolic_u0 = true, u0_constructor,
check_length, eval_expression, eval_module, kwargs...)
h_gen = generate_history(sys, u0; expression = Val{true})
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
Expand All @@ -1008,7 +1011,8 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...))
end

function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -1029,7 +1033,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
end
f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
symbolic_u0 = true, eval_expression, eval_module,
symbolic_u0 = true, eval_expression, eval_module, u0_constructor,
check_length, kwargs...)
h_gen = generate_history(sys, u0; expression = Val{true})
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
Expand Down Expand Up @@ -1057,9 +1061,10 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
else
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
end
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
# Call `remake` so it runs initialization if it is trivial
return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
noise_rate_prototype =
noise_rate_prototype, kwargs1..., kwargs...)
noise_rate_prototype, kwargs1..., kwargs...))
end

"""
Expand Down
5 changes: 3 additions & 2 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,9 @@ function DiffEqBase.SDEProblem{iip, specialize}(

kwargs = filter_kwargs(kwargs)

SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...)
# Call `remake` so it runs initialization if it is trivial
return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...))
end

function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
# Call `remake` so it runs initialization if it is trivial
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
end

"""
Expand Down Expand Up @@ -548,7 +549,8 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
pt = something(get_metadata(sys), StandardNonlinearProblem())
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
# Call `remake` so it runs initialization if it is trivial
return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...))
end

const TypeT = Union{DataType, UnionAll}
Expand Down
11 changes: 8 additions & 3 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ All other keyword arguments are forwarded to `InitializationProblem`.
"""
function maybe_build_initialization_problem(
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
guesses, missing_unknowns; implicit_dae = false, kwargs...)
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
Expand All @@ -615,7 +615,7 @@ function maybe_build_initialization_problem(
if is_time_dependent(sys)
all_init_syms = Set(all_symbols(initializeprob))
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
initializeprobmap = getu(initializeprob, solved_unknowns)
initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns)
else
initializeprobmap = nothing
end
Expand Down Expand Up @@ -774,14 +774,19 @@ function process_SciMLProblem(
op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0map, pmap, defs, cmap, dvs, ps)

if u0_constructor === identity && u0Type <: StaticArray
u0_constructor = vals -> SymbolicUtils.Code.create_array(
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
end
if build_initializeprob
kws = maybe_build_initialization_problem(
sys, op, u0map, pmap, t, defs, guesses, missing_unknowns;
implicit_dae, warn_initialize_determined, initialization_eqs,
eval_expression, eval_module, fully_determined,
warn_cyclic_dependency, check_units = check_initialization_units,
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete)
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
u0_constructor)

kwargs = merge(kwargs, kws)
end
Expand Down
32 changes: 30 additions & 2 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ end
@parameters x0 y0
@mtkbuild sys = ODESystem([x ~ x0, y ~ y0, s ~ x + y], t; guesses = [y0 => 0.0])
prob = ODEProblem(sys, [s => 1.0], (0.0, 1.0), [x0 => 0.3, y0 => missing])
@test prob.ps[y0] ≈ 0.0
# trivial initialization run immediately
@test prob.ps[y0] ≈ 0.7
@test init(prob, Tsit5()).ps[y0] ≈ 0.7
@test solve(prob, Tsit5()).ps[y0] ≈ 0.7
end
Expand All @@ -745,7 +746,8 @@ end
systems = [fixed, spring, mass, gravity, constant, damper],
guesses = [spring.s_rel0 => 1.0])
prob = ODEProblem(sys, [], (0.0, 1.0), [spring.s_rel0 => missing])
@test prob.ps[spring.s_rel0] ≈ 0.0
# trivial initialization run immediately
@test prob.ps[spring.s_rel0] ≈ -3.905
@test init(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905
@test solve(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905
end
Expand Down Expand Up @@ -1388,3 +1390,29 @@ end
integ1 = init(oprob1)
@test integ1[X1] ≈ 1.0
end

@testset "Trivial initialization is run on problem construction" begin
@variables _x(..) y(t)
@brownian a
@parameters tot
x = _x(t)
@testset "$Problem" for (Problem, lhs, rhs) in [
(ODEProblem, D, 0.0),
(SDEProblem, D, a),
(DDEProblem, D, _x(t - 0.1)),
(SDDEProblem, D, _x(t - 0.1) + a)
]
@mtkbuild sys = ModelingToolkit.System([lhs(x) ~ x + rhs, x + y ~ tot], t;
guesses = [tot => 1.0], defaults = [tot => missing])
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
@test prob.ps[tot] ≈ 2.0
end
@testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem]
@parameters p1 p2
@mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2];
parameter_dependencies = [p2 ~ 2p1],
guesses = [p1 => 0.0], defaults = [p1 => missing])
prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0])
@test prob.ps[p1] ≈ 3.0
end
end
Loading