Skip to content

Commit 71c9c66

Browse files
feat: allow retaining user-provided observed in mtkcompile under conditions
Observed are only retained if they are not involved in the equations of the system. Exceptions to this rule error.
1 parent 9cdbcf3 commit 71c9c66

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ function update_simplified_system!(
960960
obs_sub[eq.lhs] = eq.rhs
961961
end
962962
# TODO: compute the dependency correctly so that we don't have to do this
963-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
963+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; state.additional_observed]
964964

965965
unknown_idxs = filter(
966966
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))

src/systems/systemstructure.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
211+
"""
212+
Additional user-provided observed equations. The variables calculated here
213+
are not used in the rest of the system.
214+
"""
215+
additional_observed::Vector{Equation}
211216
end
212217

213218
TransformationState(sys::AbstractSystem) = TearingState(sys)
@@ -277,6 +282,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
277282
# flatten array equations
278283
eqs = flatten_equations(equations(sys))
279284
neqs = length(eqs)
285+
obseqs = observed(sys)
286+
obsvars = Set([eq.lhs for eq in obseqs])
287+
@set! sys.observed = Equation[]
280288
param_derivative_map = Dict{BasicSymbolic, Any}()
281289
# * Scalarize unknowns
282290
dvs = Set{BasicSymbolic}()
@@ -320,6 +328,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
320328
varsbuf = Set()
321329
eqs_to_retain = trues(length(eqs))
322330
for (i, eq) in enumerate(eqs)
331+
_eq = eq
323332
if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential &&
324333
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv)
325334
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
@@ -347,6 +356,14 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
347356
push!(incidence, v)
348357
end
349358

359+
if v in obsvars
360+
error("""
361+
Observed equations in unsimplified systems cannot compute quantities that \
362+
are involved in the equations of the system. Encountered observed \
363+
variable `$v` in equation `$_eq`.
364+
""")
365+
end
366+
350367
# TODO: Can we handle this without `isparameter`?
351368
if symbolic_contains(v, ps) ||
352369
getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v)
@@ -516,7 +533,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
516533
ts = TearingState(sys, fullvars,
517534
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
518535
complete(graph), nothing, var_types, false),
519-
Any[], param_derivative_map)
536+
Any[], param_derivative_map, obseqs)
520537

521538
return ts
522539
end

0 commit comments

Comments
 (0)