@@ -175,9 +175,19 @@ function __mtkcompile(sys::AbstractSystem;
175175 end
176176 # Nonlinear system
177177 if ! has_derivatives && ! has_shifts
178+ obseqs = Equation[]
179+ get_trivial_observed_equations! (Equation[], eqs, obseqs, all_dvs, nothing )
180+ add_array_observed! (obseqs)
181+ obseqs = topsort_equations (obseqs, [eq. lhs for eq in obseqs])
178182 map! (eq -> Symbolics. COMMON_ZERO ~ (eq. rhs - eq. lhs), eqs, eqs)
183+ observables = Set {SymbolicT} ()
184+ for eq in obseqs
185+ push! (observables, eq. lhs)
186+ end
187+ setdiff! (flat_dvs, observables)
179188 @set! sys. eqs = eqs
180189 @set! sys. unknowns = flat_dvs
190+ @set! sys. observed = obseqs
181191 return sys
182192 end
183193 iv = get_iv (sys):: SymbolicT
@@ -284,6 +294,9 @@ function __mtkcompile(sys::AbstractSystem;
284294 BSImpl. Term (; args) => args[1 ]
285295 end )
286296 end
297+ get_trivial_observed_equations! (diffeqs, alg_eqs, obseqs, all_dvs, iv)
298+ add_array_observed! (obseqs)
299+ obseqs = topsort_equations (obseqs, [eq. lhs for eq in obseqs])
287300 for i in eachindex (alg_eqs)
288301 eq = alg_eqs[i]
289302 alg_eqs[i] = 0 ~ subst (eq. rhs - eq. lhs)
@@ -331,6 +344,125 @@ function __mtkcompile(sys::AbstractSystem;
331344 return sys
332345end
333346
347+ """
348+ $TYPEDSIGNATURES
349+
350+ For explicit algebraic equations in `algeqs`, find ones where the RHS is a function of
351+ differential variables or other observed variables. These equations are removed from
352+ `algeqs` and appended to `obseqs`. The process runs iteratively until a fixpoint is
353+ reached.
354+ """
355+ function get_trivial_observed_equations! (diffeqs:: Vector{Equation} , algeqs:: Vector{Equation} ,
356+ obseqs:: Vector{Equation} , all_dvs:: Set{SymbolicT} ,
357+ @nospecialize (iv:: Union{SymbolicT, Nothing} ))
358+ # Maximum number of times to loop over all algebraic equations
359+ maxiters = 100
360+ # Whether it's worth doing another loop, or we already reached a fixpoint
361+ active = true
362+
363+ current_observed = Set {SymbolicT} ()
364+ for eq in obseqs
365+ push! (current_observed, eq. lhs)
366+ end
367+ diffvars = Set {SymbolicT} ()
368+ for eq in diffeqs
369+ push! (diffvars, Moshi. Match. @match eq. lhs begin
370+ BSImpl. Term (; f, args) && if f isa Union{Shift, Differential} end => args[1 ]
371+ end )
372+ end
373+ # Incidence information
374+ vars_in_each_algeq = Set{SymbolicT}[]
375+ sizehint! (vars_in_each_algeq, length (algeqs))
376+ for eq in algeqs
377+ buffer = Set {SymbolicT} ()
378+ SU. search_variables! (buffer, eq. rhs)
379+ # We only care for variables
380+ intersect! (buffer, all_dvs)
381+ # If `eq.lhs` is only dependent on differential or other observed variables,
382+ # we can tear it. So we don't care about those either.
383+ setdiff! (buffer, diffvars)
384+ setdiff! (buffer, current_observed)
385+ if iv isa SymbolicT
386+ delete! (buffer, iv)
387+ end
388+ push! (vars_in_each_algeq, buffer)
389+ end
390+ # Algebraic equations that we still consider for elimination
391+ active_alg_eqs = trues (length (algeqs))
392+ # The number of equations we're considering for elimination
393+ candidate_eqs_count = length (algeqs)
394+ # Algebraic equations that we still consider algebraic
395+ alg_eqs_mask = trues (length (algeqs))
396+ # Observed variables added by this process
397+ new_observed_variables = Set {SymbolicT} ()
398+ while active && maxiters > 0 && candidate_eqs_count > 0
399+ # We've reached a fixpoint unless the inner loop adds an observed equation
400+ active = false
401+ for i in eachindex (algeqs)
402+ # Ignore if we're not considering this for elimination or it is already eliminated
403+ active_alg_eqs[i] || continue
404+ alg_eqs_mask[i] || continue
405+ eq = algeqs[i]
406+ candidate_var = eq. lhs
407+ # LHS must be an unknown and must not be another observed
408+ if ! (candidate_var in all_dvs) || candidate_var in new_observed_variables
409+ active_alg_eqs[i] = false
410+ candidate_eqs_count -= 1
411+ continue
412+ end
413+ # Remove newly added observed variables
414+ vars_in_algeq = vars_in_each_algeq[i]
415+ setdiff! (vars_in_algeq, new_observed_variables)
416+ # If the incidence is empty, it is a function of observed and diffvars
417+ isempty (vars_in_algeq) || continue
418+
419+ # We added an observed equation, so we haven't reached a fixpoint yet
420+ active = true
421+ push! (new_observed_variables, candidate_var)
422+ push! (obseqs, eq)
423+ # This is no longer considered for elimination
424+ active_alg_eqs[i] = false
425+ candidate_eqs_count -= 1
426+ # And is no longer algebraic
427+ alg_eqs_mask[i] = false
428+ end
429+ # Safeguard against infinite loops, because `while true` is potentially dangerous
430+ maxiters -= 1
431+ end
432+
433+ keepat! (algeqs, alg_eqs_mask)
434+ end
435+
436+ function offset_array (origin, arr)
437+ if all (isone, origin)
438+ return arr
439+ end
440+ return Origin (origin)(arr)
441+ end
442+
443+ @register_array_symbolic offset_array (origin:: Any , arr:: AbstractArray ) begin
444+ size = size (arr)
445+ eltype = eltype (arr)
446+ ndims = ndims (arr)
447+ end
448+
449+ function add_array_observed! (obseqs:: Vector{Equation} )
450+ array_obsvars = Set {SymbolicT} ()
451+ for eq in obseqs
452+ arr, isarr = split_indexed_var (eq. lhs)
453+ isarr && push! (array_obsvars, arr)
454+ end
455+ for var in array_obsvars
456+ firstind = first (SU. stable_eachindex (var)):: SU.StableIndex{Int}
457+ firstind = Tuple (firstind. idxs)
458+ scal = SymbolicT[]
459+ for i in SU. stable_eachindex (var)
460+ push! (scal, var[i])
461+ end
462+ push! (obseqs, var ~ offset_array (firstind, reshape (scal, size (var))))
463+ end
464+ end
465+
334466function simplify_sde_system (sys:: AbstractSystem ; kwargs... )
335467 brown_vars = brownians (sys)
336468 @set! sys. brownians = SymbolicT[]
0 commit comments