Skip to content

Commit fd1d16d

Browse files
feat: separate trivially identifiable observed equations in initialization system
Greatly reduces `mtkcompile` time
1 parent 9584a4e commit fd1d16d

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,19 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
198198
end
199199
append!(eqs_ics, trueobs)
200200

201+
# 9) optimize equations that are guaranteed to be observed
202+
eqs_ics, explicit_observed = separate_trivial_equations(eqs_ics)
203+
eliminated_vars = Set([eq.lhs for eq in explicit_observed])
204+
201205
vars = [vars; collect(solved_params)]
206+
filter!(!in(eliminated_vars), vars)
202207

203208
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
204209
merge!(defs, initials)
205210
isys = System(Vector{Equation}(eqs_ics),
206211
vars,
207212
pars;
213+
observed = explicit_observed,
208214
defaults = defs,
209215
checks = check_units,
210216
name,
@@ -307,17 +313,26 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
307313
!in(solved_params), parameters(sys; initial_parameters = true)))
308314
vars = collect(solved_params)
309315

316+
# optimize equations that are guaranteed to be observed
317+
eqs_ics, explicit_observed = separate_trivial_equations(eqs_ics)
318+
eliminated_vars = Set([eq.lhs for eq in explicit_observed])
319+
320+
vars = [vars; collect(solved_params)]
321+
filter!(!in(eliminated_vars), vars)
322+
310323
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
311324
merge!(defs, initials)
312325
isys = System(Vector{Equation}(eqs_ics),
313326
vars,
314327
pars;
328+
observed = explicit_observed,
315329
defaults = defs,
316330
checks = check_units,
317331
name,
318332
is_initializesystem = true,
319333
kwargs...)
320334
@set isys.parameter_dependencies = new_parameter_deps
335+
return isys
321336
end
322337

323338
"""
@@ -470,6 +485,41 @@ function filter_delay_equations_variables!(sys::AbstractSystem, trueobs::Vector{
470485
return deleteat!(trueobs, idxs_to_remove)
471486
end
472487

488+
function separate_trivial_equations(eqs::Vector{Equation})
489+
# Many equations (observed, ICs for differential variables) are explicit and will be
490+
# converted to observed by `mtkcompile`, but doing so is also expensive. All such
491+
# equations are explicit in `eqs_ics`, so we find all variables determined by such
492+
# explicit equations exactly once (e.g. an observed variable given an initial condition
493+
# will occur twice on the LHS, ruling it out). These are substituted into the rest so
494+
# the determined-ness of the system isn't affected and will be added back to the
495+
# simplified system as observed.
496+
occurrences = Dict{BasicSymbolic, Vector{Int}}()
497+
for (i, eq) in enumerate(eqs)
498+
symbolic_type(eq.lhs) == NotSymbolic() && continue
499+
buffer = get!(() -> Int[], occurrences, eq.lhs)
500+
push!(buffer, i)
501+
end
502+
503+
blacklist = BitSet()
504+
subrules = Dict()
505+
explicit_observed = Equation[]
506+
for (sym, idxs) in occurrences
507+
length(idxs) == 1 || continue
508+
idx = only(idxs)
509+
subrules[sym] = eqs[idx].rhs
510+
push!(blacklist, idx)
511+
push!(explicit_observed, sym ~ eqs[idx].rhs)
512+
end
513+
514+
new_eqs = Equation[]
515+
for (i, eq) in enumerate(eqs)
516+
i in blacklist && continue
517+
push!(new_eqs, fixpoint_sub(eq, subrules; limit = length(subrules)))
518+
end
519+
520+
return new_eqs, explicit_observed
521+
end
522+
473523
"""
474524
$(TYPEDSIGNATURES)
475525

0 commit comments

Comments
 (0)