Skip to content

Commit da00fed

Browse files
feat: separate trivially identifiable observed equations in initialization system
Greatly reduces `mtkcompile` time
1 parent 3945384 commit da00fed

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,19 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
198198
end
199199
append!(eqs_ics, trueobs)
200200

201+
# 9) optimize equations that are guaranteed to be observed
202+
eqs_ics, explicit_observed = separate_trivial_equations(sys, eqs_ics)
203+
eliminated_vars = Set([eq.lhs for eq in explicit_observed])
204+
201205
vars = [vars; collect(solved_params)]
206+
filter!(!in(eliminated_vars), vars)
202207

203208
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
204209
merge!(defs, initials)
205210
isys = System(Vector{Equation}(eqs_ics),
206211
vars,
207212
pars;
213+
observed = explicit_observed,
208214
defaults = defs,
209215
checks = check_units,
210216
name,
@@ -307,17 +313,26 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
307313
!in(solved_params), parameters(sys; initial_parameters = true)))
308314
vars = collect(solved_params)
309315

316+
# optimize equations that are guaranteed to be observed
317+
eqs_ics, explicit_observed = separate_trivial_equations(sys, eqs_ics)
318+
eliminated_vars = Set([eq.lhs for eq in explicit_observed])
319+
320+
vars = [vars; collect(solved_params)]
321+
filter!(!in(eliminated_vars), vars)
322+
310323
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
311324
merge!(defs, initials)
312325
isys = System(Vector{Equation}(eqs_ics),
313326
vars,
314327
pars;
328+
observed = explicit_observed,
315329
defaults = defs,
316330
checks = check_units,
317331
name,
318332
is_initializesystem = true,
319333
kwargs...)
320334
@set isys.parameter_dependencies = new_parameter_deps
335+
return isys
321336
end
322337

323338
"""
@@ -470,6 +485,54 @@ function filter_delay_equations_variables!(sys::AbstractSystem, trueobs::Vector{
470485
return deleteat!(trueobs, idxs_to_remove)
471486
end
472487

488+
function separate_trivial_equations(sys::AbstractSystem, eqs::Vector{Equation})
489+
# Many equations (observed, ICs for differential variables) are explicit and will be
490+
# converted to observed by `mtkcompile`, but doing so is also expensive. All such
491+
# equations are explicit in `eqs_ics`, so we find all variables determined by such
492+
# explicit equations exactly once (e.g. an observed variable given an initial condition
493+
# will occur twice on the LHS, ruling it out). These are substituted into the rest so
494+
# the determined-ness of the system isn't affected and will be added back to the
495+
# simplified system as observed.
496+
occurrences = Dict{BasicSymbolic, Vector{Int}}()
497+
for (i, eq) in enumerate(eqs)
498+
symbolic_type(eq.lhs) == NotSymbolic() && continue
499+
# if we eliminate `x ~ Initial(x)` equations, then for overdetermined systems
500+
# redundant conditions end up being parameter equations in terms of `Initial(..)`
501+
# and don't play well with the rest of the infrastructure. So we only eliminate
502+
# parameter equations and observed.
503+
is_variable(sys, eq.lhs) && continue
504+
if iscall(eq.lhs)
505+
op = operation(eq.lhs)
506+
args = arguments(eq.lhs)
507+
if !(issym(op) ||
508+
op === getindex && (!iscall(args[1]) || issym(operation(args[1]))))
509+
continue
510+
end
511+
end
512+
buffer = get!(() -> Int[], occurrences, eq.lhs)
513+
push!(buffer, i)
514+
end
515+
516+
blacklist = BitSet()
517+
subrules = Dict()
518+
explicit_observed = Equation[]
519+
for (sym, idxs) in occurrences
520+
length(idxs) == 1 || continue
521+
idx = only(idxs)
522+
subrules[sym] = eqs[idx].rhs
523+
push!(blacklist, idx)
524+
push!(explicit_observed, sym ~ eqs[idx].rhs)
525+
end
526+
527+
new_eqs = Equation[]
528+
for (i, eq) in enumerate(eqs)
529+
i in blacklist && continue
530+
push!(new_eqs, fixpoint_sub(eq, subrules; maxiters = length(subrules)))
531+
end
532+
533+
return new_eqs, explicit_observed
534+
end
535+
473536
"""
474537
$(TYPEDSIGNATURES)
475538

0 commit comments

Comments
 (0)