Skip to content

Commit ba2820f

Browse files
feat: add trivial form of tearing to MTKBase's mtkcompile
1 parent 07fc4aa commit ba2820f

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

lib/ModelingToolkitBase/src/systems/nonlinear/initializesystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,15 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w
739739
initialization.
740740
"""
741741
function unhack_observed(obseqs, eqs)
742+
mask = trues(length(obseqs))
743+
for (i, eq) in enumerate(obseqs)
744+
mask[i] = Moshi.Match.@match eq.rhs begin
745+
BSImpl.Term(; f) => f !== offset_array
746+
_ => true
747+
end
748+
end
749+
750+
obseqs = obseqs[mask]
742751
return obseqs, eqs
743752
end
744753

lib/ModelingToolkitBase/src/systems/systems.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,19 @@ function __mtkcompile(sys::AbstractSystem;
175175
end
176176
# Nonlinear system
177177
if !has_derivatives && !has_shifts
178+
obseqs = Equation[]
179+
get_trivial_observed_equations!(Equation[], eqs, obseqs, all_dvs, nothing)
180+
add_array_observed!(obseqs)
181+
obseqs = topsort_equations(obseqs, [eq.lhs for eq in obseqs])
178182
map!(eq -> Symbolics.COMMON_ZERO ~ (eq.rhs - eq.lhs), eqs, eqs)
183+
observables = Set{SymbolicT}()
184+
for eq in obseqs
185+
push!(observables, eq.lhs)
186+
end
187+
setdiff!(flat_dvs, observables)
179188
@set! sys.eqs = eqs
180189
@set! sys.unknowns = flat_dvs
190+
@set! sys.observed = obseqs
181191
return sys
182192
end
183193
iv = get_iv(sys)::SymbolicT
@@ -284,6 +294,9 @@ function __mtkcompile(sys::AbstractSystem;
284294
BSImpl.Term(; args) => args[1]
285295
end)
286296
end
297+
get_trivial_observed_equations!(diffeqs, alg_eqs, obseqs, all_dvs, iv)
298+
add_array_observed!(obseqs)
299+
obseqs = topsort_equations(obseqs, [eq.lhs for eq in obseqs])
287300
for i in eachindex(alg_eqs)
288301
eq = alg_eqs[i]
289302
alg_eqs[i] = 0 ~ subst(eq.rhs - eq.lhs)
@@ -331,6 +344,125 @@ function __mtkcompile(sys::AbstractSystem;
331344
return sys
332345
end
333346

347+
"""
348+
$TYPEDSIGNATURES
349+
350+
For explicit algebraic equations in `algeqs`, find ones where the RHS is a function of
351+
differential variables or other observed variables. These equations are removed from
352+
`algeqs` and appended to `obseqs`. The process runs iteratively until a fixpoint is
353+
reached.
354+
"""
355+
function get_trivial_observed_equations!(diffeqs::Vector{Equation}, algeqs::Vector{Equation},
356+
obseqs::Vector{Equation}, all_dvs::Set{SymbolicT},
357+
@nospecialize(iv::Union{SymbolicT, Nothing}))
358+
# Maximum number of times to loop over all algebraic equations
359+
maxiters = 100
360+
# Whether it's worth doing another loop, or we already reached a fixpoint
361+
active = true
362+
363+
current_observed = Set{SymbolicT}()
364+
for eq in obseqs
365+
push!(current_observed, eq.lhs)
366+
end
367+
diffvars = Set{SymbolicT}()
368+
for eq in diffeqs
369+
push!(diffvars, Moshi.Match.@match eq.lhs begin
370+
BSImpl.Term(; f, args) && if f isa Union{Shift, Differential} end => args[1]
371+
end)
372+
end
373+
# Incidence information
374+
vars_in_each_algeq = Set{SymbolicT}[]
375+
sizehint!(vars_in_each_algeq, length(algeqs))
376+
for eq in algeqs
377+
buffer = Set{SymbolicT}()
378+
SU.search_variables!(buffer, eq.rhs)
379+
# We only care for variables
380+
intersect!(buffer, all_dvs)
381+
# If `eq.lhs` is only dependent on differential or other observed variables,
382+
# we can tear it. So we don't care about those either.
383+
setdiff!(buffer, diffvars)
384+
setdiff!(buffer, current_observed)
385+
if iv isa SymbolicT
386+
delete!(buffer, iv)
387+
end
388+
push!(vars_in_each_algeq, buffer)
389+
end
390+
# Algebraic equations that we still consider for elimination
391+
active_alg_eqs = trues(length(algeqs))
392+
# The number of equations we're considering for elimination
393+
candidate_eqs_count = length(algeqs)
394+
# Algebraic equations that we still consider algebraic
395+
alg_eqs_mask = trues(length(algeqs))
396+
# Observed variables added by this process
397+
new_observed_variables = Set{SymbolicT}()
398+
while active && maxiters > 0 && candidate_eqs_count > 0
399+
# We've reached a fixpoint unless the inner loop adds an observed equation
400+
active = false
401+
for i in eachindex(algeqs)
402+
# Ignore if we're not considering this for elimination or it is already eliminated
403+
active_alg_eqs[i] || continue
404+
alg_eqs_mask[i] || continue
405+
eq = algeqs[i]
406+
candidate_var = eq.lhs
407+
# LHS must be an unknown and must not be another observed
408+
if !(candidate_var in all_dvs) || candidate_var in new_observed_variables
409+
active_alg_eqs[i] = false
410+
candidate_eqs_count -= 1
411+
continue
412+
end
413+
# Remove newly added observed variables
414+
vars_in_algeq = vars_in_each_algeq[i]
415+
setdiff!(vars_in_algeq, new_observed_variables)
416+
# If the incidence is empty, it is a function of observed and diffvars
417+
isempty(vars_in_algeq) || continue
418+
419+
# We added an observed equation, so we haven't reached a fixpoint yet
420+
active = true
421+
push!(new_observed_variables, candidate_var)
422+
push!(obseqs, eq)
423+
# This is no longer considered for elimination
424+
active_alg_eqs[i] = false
425+
candidate_eqs_count -= 1
426+
# And is no longer algebraic
427+
alg_eqs_mask[i] = false
428+
end
429+
# Safeguard against infinite loops, because `while true` is potentially dangerous
430+
maxiters -= 1
431+
end
432+
433+
keepat!(algeqs, alg_eqs_mask)
434+
end
435+
436+
function offset_array(origin, arr)
437+
if all(isone, origin)
438+
return arr
439+
end
440+
return Origin(origin)(arr)
441+
end
442+
443+
@register_array_symbolic offset_array(origin::Any, arr::AbstractArray) begin
444+
size = size(arr)
445+
eltype = eltype(arr)
446+
ndims = ndims(arr)
447+
end
448+
449+
function add_array_observed!(obseqs::Vector{Equation})
450+
array_obsvars = Set{SymbolicT}()
451+
for eq in obseqs
452+
arr, isarr = split_indexed_var(eq.lhs)
453+
isarr && push!(array_obsvars, arr)
454+
end
455+
for var in array_obsvars
456+
firstind = first(SU.stable_eachindex(var))::SU.StableIndex{Int}
457+
firstind = Tuple(firstind.idxs)
458+
scal = SymbolicT[]
459+
for i in SU.stable_eachindex(var)
460+
push!(scal, var[i])
461+
end
462+
push!(obseqs, var ~ offset_array(firstind, reshape(scal, size(var))))
463+
end
464+
end
465+
334466
function simplify_sde_system(sys::AbstractSystem; kwargs...)
335467
brown_vars = brownians(sys)
336468
@set! sys.brownians = SymbolicT[]

0 commit comments

Comments
 (0)