Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0948934
refactor: do not build and use `paramsubs` in `generate_initializesys…
AayushSabharwal Jun 17, 2025
58a47ab
feat: add fast path for constructing `MTKParameters` in `process_SciM…
AayushSabharwal Jun 17, 2025
c8618c0
refactor: avoid unnecessary computation in `evaluate_varmap!`
AayushSabharwal Jun 17, 2025
9982ed9
refactor: avoid unnecesary computation in `build_operating_point!`
AayushSabharwal Jun 17, 2025
dd8b432
refactor: batch computation of temporary values in `maybe_build_initi…
AayushSabharwal Jun 17, 2025
29eee6d
feat: add an always-present mutable cache key to the system
AayushSabharwal Jun 19, 2025
b86cc52
fix: fix unnecessary warnings when no stream connections in `expand_c…
AayushSabharwal Jun 19, 2025
d3a7e71
feat: cache intermediate results for `observed_equations_used_by`
AayushSabharwal Jun 19, 2025
3a0abbb
fix: remove unnecessary scalarization in `InitializationProblem`
AayushSabharwal Jun 19, 2025
27702b6
refactor: add fast path in `build_operating_point!`
AayushSabharwal Jun 19, 2025
7db8dec
fix: do not rely on metadata in `process_parameter_equations`
AayushSabharwal Jun 19, 2025
5bfb2a9
refactor: remove source of allocations in `InitializationProblem`
AayushSabharwal Jun 20, 2025
fffc983
fix: properly handle values given to parameter dependencies in `late_…
AayushSabharwal Jun 20, 2025
7ec630d
feat: invalidate cache in `@set!`
AayushSabharwal Jun 20, 2025
d4079a8
refactor: update tests to account for new initsys generation
AayushSabharwal Jun 20, 2025
7709760
fix: handle metadata merging in `extend`
AayushSabharwal Jun 24, 2025
d975861
test: update metadata tests
AayushSabharwal Jun 24, 2025
a501375
feat: add the ability to completely remove vertices from `BipartiteGr…
AayushSabharwal Jun 24, 2025
2c04b7b
feat: preemptively tear some trivial equations in `mtkcompile`
AayushSabharwal Jun 19, 2025
b98e12f
ci: better handle compile time in benchmarks
AayushSabharwal Jun 24, 2025
faf8bce
ci: add benchmark for large parameter initialization model
AayushSabharwal Jun 24, 2025
a6df4cf
fix: ensure `initializeprobpmap` returns floats
AayushSabharwal Jun 25, 2025
dc0975f
feat: implement `SymbolicUtils.hasmetadata` for `AbstractSystem`
AayushSabharwal Jun 25, 2025
f4e81de
fix: fix potential infinite recursion in `simplify_optimization_system`
AayushSabharwal Jun 25, 2025
b0da478
test: mark test optimization test as no longer broken
AayushSabharwal Jun 25, 2025
2efd558
test: make optimization test robust to simplification
AayushSabharwal Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Mechanical.Rotational
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEqDefault
using ModelingToolkit: t_nounits as t, D_nounits as D

const SUITE = BenchmarkGroup()

Expand Down Expand Up @@ -45,12 +46,33 @@ end

@named model = DCMotor()

# first call
mtkcompile(model)
SUITE["mtkcompile"] = @benchmarkable mtkcompile($model)

model = mtkcompile(model)
u0 = unknowns(model) .=> 0.0
tspan = (0.0, 6.0)
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)

prob = ODEProblem(model, u0, tspan)
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)

# first call
init(prob)
SUITE["init"] = @benchmarkable init($prob)

large_param_init = SUITE["large_parameter_init"] = BenchmarkGroup()

N = 25
@variables x(t)[1:N]
@parameters A[1:N, 1:N]

defval = collect(x) * collect(x)'
@mtkcompile model = System(
[D(x) ~ x], t, [x], [A]; defaults = [A => defval], guesses = [A => fill(NaN, N, N)])

u0 = [x => rand(N)]
prob = ODEProblem(model, u0, tspan)
large_param_init["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)

large_param_init["init"] = @benchmarkable init($prob)
30 changes: 28 additions & 2 deletions src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
end
end

function delete_srcs!(g::BipartiteGraph, srcs)
function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I}
for s in srcs
set_neighbors!(g, s, ())
end
if rm_verts
old_to_new_idxs = collect(one(I):I(nsrcs(g)))
for s in srcs
old_to_new_idxs[s] = zero(I)
end
offset = zero(I)
for i in eachindex(old_to_new_idxs)
if iszero(old_to_new_idxs[i])
offset += one(I)
continue
end
old_to_new_idxs[i] -= offset
end

if g.badjlist isa AbstractVector
for i in 1:ndsts(g)
for j in eachindex(g.badjlist[i])
g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]]
end
filter!(!iszero, g.badjlist[i])
end
end
deleteat!(g.fadjlist, srcs)
end
g
end
delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs)
function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false)
delete_srcs!(invview(g), srcs; rm_verts)
end

###
### Edges iteration
Expand Down
4 changes: 2 additions & 2 deletions src/problems/initializationproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
for k in keys(op)
has_u0_ics |= is_variable(sys, k) || isdifferential(k) ||
symbolic_type(k) == ArraySymbolic() &&
is_sized_array_symbolic(k) && is_variable(sys, first(collect(k)))
is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k))))
end
if !has_u0_ics && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys; initialization_eqs, check_units)
Expand Down Expand Up @@ -79,7 +79,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
@warn errmsg
end

uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)])
uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys))

# TODO: throw on uninitialized arrays
filter!(x -> !(x isa Symbolics.Arr), uninit)
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,8 @@ function update_simplified_system!(
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
fast_substitute(state.additional_observed, obs_sub)]

unknown_idxs = filter(
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))
Expand Down
34 changes: 31 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,23 @@ has_equations(::AbstractSystem) = true

Invalidate cached jacobians, etc.
"""
invalidate_cache!(sys::AbstractSystem) = sys
function invalidate_cache!(sys::AbstractSystem)
has_metadata(sys) || return sys
empty!(getmetadata(sys, MutableCacheKey, nothing))
return sys
end

# `::MetadataT` but that is defined later
function refreshed_metadata(meta::Base.ImmutableDict)
newmeta = MetadataT()
for (k, v) in meta
if k === MutableCacheKey
v = MutableCacheT()
end
newmeta = Base.ImmutableDict(newmeta, k => v)
end
return newmeta
end

function Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field}
getfield(obj, field)
Expand All @@ -815,6 +831,8 @@ end
args = map(fieldnames(obj)) do fn
if fn in fieldnames(patch)
:(patch.$fn)
elseif fn == :metadata
:($refreshed_metadata(getfield(obj, $(Meta.quot(fn)))))
else
:(getfield(obj, $(Meta.quot(fn))))
end
Expand Down Expand Up @@ -2507,7 +2525,15 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
meta = merge(get_metadata(basesys), get_metadata(sys))
meta = MetadataT()
for kvp in get_metadata(basesys)
kvp[1] == MutableCacheKey && continue
meta = Base.ImmutableDict(meta, kvp)
end
for kvp in get_metadata(sys)
kvp[1] == MutableCacheKey && continue
meta = Base.ImmutableDict(meta, kvp)
end
syss = union(get_systems(basesys), get_systems(sys))
args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps)
kwargs = (observed = obs, continuous_events = cevs,
Expand Down Expand Up @@ -2705,7 +2731,9 @@ function process_parameter_equations(sys::AbstractSystem)
is_sized_array_symbolic(sym) &&
all(Base.Fix1(is_parameter, sys), collect(sym))
end
if !isparameter(eq.lhs)
# Everything in `varsbuf` is a parameter, so this is a cheap `is_parameter`
# check.
if !(eq.lhs in varsbuf)
throw(ArgumentError("""
LHS of parameter dependency equation must be a single parameter. Found \
$(eq.lhs).
Expand Down
3 changes: 2 additions & 1 deletion src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,8 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
eqs = [equations(sys); ceqs; stream_eqs]
# substitute `instream(..)` expressions with their new values
for i in eachindex(eqs)
eqs[i] = fixpoint_sub(eqs[i], instream_subs; maxiters = length(instream_subs))
eqs[i] = fixpoint_sub(
eqs[i], instream_subs; maxiters = max(length(instream_subs), 10))
end
# get the defaults for domain networks
d_defs = domain_defaults(sys, domain_csets)
Expand Down
106 changes: 54 additions & 52 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,20 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
end

# 5) process parameters as initialization unknowns
paramsubs = setup_parameter_initialization!(
solved_params = setup_parameter_initialization!(
sys, pmap, defs, guesses, eqs_ics; check_defguess)

# 6) parameter dependencies become equations, their LHS become unknowns
# non-numeric dependent parameters stay as parameter dependencies
new_parameter_deps = solve_parameter_dependencies!(
sys, paramsubs, eqs_ics, defs, guesses)
sys, solved_params, eqs_ics, defs, guesses)

# 7) handle values provided for dependent parameters similar to values for observed variables
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics, paramsubs)
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics)

# parameters do not include ones that became initialization unknowns
pars = Vector{SymbolicParam}(filter(
p -> !haskey(paramsubs, p), parameters(sys; initial_parameters = true)))
!in(solved_params), parameters(sys; initial_parameters = true)))
push!(pars, get_iv(sys))

# 8) use observed equations for guesses of observed variables if not provided
Expand All @@ -198,16 +198,8 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
end
append!(eqs_ics, trueobs)

vars = [vars; collect(values(paramsubs))]
vars = [vars; collect(solved_params)]

# even if `p => tovar(p)` is in `paramsubs`, `isparameter(p[1]) === true` after substitution
# so add scalarized versions as well
scalarize_varmap!(paramsubs)

eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
end
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
merge!(defs, initials)
isys = System(Vector{Equation}(eqs_ics),
Expand Down Expand Up @@ -299,30 +291,22 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
append!(eqs_ics, initialization_eqs)

# process parameters as initialization unknowns
paramsubs = setup_parameter_initialization!(
solved_params = setup_parameter_initialization!(
sys, pmap, defs, guesses, eqs_ics; check_defguess)

# parameter dependencies become equations, their LHS become unknowns
# non-numeric dependent parameters stay as parameter dependencies
new_parameter_deps = solve_parameter_dependencies!(
sys, paramsubs, eqs_ics, defs, guesses)
sys, solved_params, eqs_ics, defs, guesses)

# handle values provided for dependent parameters similar to values for observed variables
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics, paramsubs)
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics)

# parameters do not include ones that became initialization unknowns
pars = Vector{SymbolicParam}(filter(
p -> !haskey(paramsubs, p), parameters(sys; initial_parameters = true)))
vars = collect(values(paramsubs))

# even if `p => tovar(p)` is in `paramsubs`, `isparameter(p[1]) === true` after substitution
# so add scalarized versions as well
scalarize_varmap!(paramsubs)
!in(solved_params), parameters(sys; initial_parameters = true)))
vars = collect(solved_params)

eqs_ics = Vector{Equation}(Symbolics.substitute.(eqs_ics, (paramsubs,)))
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
end
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
merge!(defs, initials)
isys = System(Vector{Equation}(eqs_ics),
Expand Down Expand Up @@ -359,7 +343,7 @@ mapping solvable parameters to their `tovar` variants.
function setup_parameter_initialization!(
sys::AbstractSystem, pmap::AbstractDict, defs::AbstractDict,
guesses::AbstractDict, eqs_ics::Vector{Equation}; check_defguess = false)
paramsubs = Dict()
solved_params = Set()
for p in parameters(sys)
if is_parameter_solvable(p, pmap, defs, guesses)
# If either of them are `missing` the parameter is an unknown
Expand All @@ -369,7 +353,7 @@ function setup_parameter_initialization!(
_val2 = get_possibly_array_fallback_singletons(defs, p)
_val3 = get_possibly_array_fallback_singletons(guesses, p)
varp = tovar(p)
paramsubs[p] = varp
push!(solved_params, p)
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
if _val2 === missing
if _val1 !== nothing && _val1 !== missing
Expand Down Expand Up @@ -409,7 +393,7 @@ function setup_parameter_initialization!(
end
end

return paramsubs
return solved_params
end

"""
Expand All @@ -418,7 +402,7 @@ end
Add appropriate parameter dependencies as initialization equations. Return the new list of
parameter dependencies for the initialization system.
"""
function solve_parameter_dependencies!(sys::AbstractSystem, paramsubs::AbstractDict,
function solve_parameter_dependencies!(sys::AbstractSystem, solved_params::AbstractSet,
eqs_ics::Vector{Equation}, defs::AbstractDict, guesses::AbstractDict)
new_parameter_deps = Equation[]
for eq in parameter_dependencies(sys)
Expand All @@ -427,7 +411,7 @@ function solve_parameter_dependencies!(sys::AbstractSystem, paramsubs::AbstractD
continue
end
varp = tovar(eq.lhs)
paramsubs[eq.lhs] = varp
push!(solved_params, eq.lhs)
push!(eqs_ics, eq)
guessval = get(guesses, eq.lhs, eq.rhs)
push!(defs, varp => guessval)
Expand All @@ -442,10 +426,10 @@ end
Turn values provided for parameter dependencies into initialization equations.
"""
function handle_dependent_parameter_constraints!(sys::AbstractSystem, pmap::AbstractDict,
eqs_ics::Vector{Equation}, paramsubs::AbstractDict)
eqs_ics::Vector{Equation})
for (k, v) in merge(defaults(sys), pmap)
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
push!(eqs_ics, paramsubs[k] ~ v)
push!(eqs_ics, k ~ v)
end
end

Expand Down Expand Up @@ -735,7 +719,25 @@ function SciMLBase.late_binding_update_u0_p(
newu0, newp = promote_u0_p(newu0, newp, t0)

# non-symbolic u0 updates initials...
if !(eltype(u0) <: Pair)
if eltype(u0) <: Pair
syms = []
vals = []
allsyms = all_symbols(sys)
for (k, v) in u0
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
end
newp = setp_oop(sys, syms)(newp, vals)
else
# if `p` is not provided or is symbolic
p === missing || eltype(p) <: Pair || return newu0, newp
(newu0 === nothing || isempty(newu0)) && return newu0, newp
Expand All @@ -748,27 +750,27 @@ function SciMLBase.late_binding_update_u0_p(
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
end
newp = meta.set_initial_unknowns!(newp, newu0)
return newu0, newp
end

syms = []
vals = []
allsyms = all_symbols(sys)
for (k, v) in u0
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end

if eltype(p) <: Pair
syms = []
vals = []
for (k, v) in p
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
newp = setp_oop(sys, syms)(newp, vals)
end

newp = setp_oop(sys, syms)(newp, vals)
return newu0, newp
end

Expand Down
Loading
Loading