Skip to content

Commit 48ec55c

Browse files
feat: add map_variables_to_equations
1 parent 3d9a8d8 commit 48ec55c

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ export TearingState
276276
export BipartiteGraph, equation_dependencies, variable_dependencies
277277
export eqeq_dependencies, varvar_dependencies
278278
export asgraph, asdigraph
279+
export map_variables_to_equations
279280

280281
export toexpr, get_variables
281282
export simplify, substitute

src/systems/systems.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,57 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
158158
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
159159
end
160160
end
161+
162+
"""
163+
$(TYPEDSIGNATURES)
164+
165+
Given a system that has been simplified via `structural_simplify`, return a `Dict` mapping
166+
variables of the system to equations that are used to solve for them. This includes
167+
observed variables.
168+
169+
# Keyword Arguments
170+
171+
- `rename_dummy_derivatives`: Whether to rename dummy derivative variable keys into their
172+
`Differential` forms. For example, this would turn the key `yˍt(t)` into
173+
`Differential(t)(y(t))`.
174+
"""
175+
function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivatives = true)
176+
if !has_tearing_state(sys)
177+
throw(ArgumentError("$(typeof(sys)) is not supported."))
178+
end
179+
ts = get_tearing_state(sys)
180+
if ts === nothing
181+
throw(ArgumentError("`map_variables_to_equations` requires a simplified system. Call `structural_simplify` on the system before calling this function."))
182+
end
183+
184+
dummy_sub = Dict()
185+
if rename_dummy_derivatives && has_schedule(sys) && (sc = get_schedule(sys)) !== nothing
186+
dummy_sub = Dict(v => k for (k, v) in sc.dummy_sub if isequal(default_toterm(k), v))
187+
end
188+
189+
mapping = Dict{Union{Num, BasicSymbolic}, Equation}()
190+
eqs = equations(sys)
191+
for eq in eqs
192+
isdifferential(eq.lhs) || continue
193+
var = arguments(eq.lhs)[1]
194+
var = get(dummy_sub, var, var)
195+
mapping[var] = eq
196+
end
197+
198+
graph = ts.structure.graph
199+
algvars = BitSet(findall(
200+
Base.Fix1(StructuralTransformations.isalgvar, ts.structure), 1:ndsts(graph)))
201+
algeqs = BitSet(findall(1:nsrcs(graph)) do eq
202+
all(!Base.Fix1(isdervar, ts.structure), 𝑠neighbors(graph, eq))
203+
end)
204+
alge_var_eq_matching = complete(maximal_matching(graph, in(algeqs), in(algvars)))
205+
for (i, eq) in enumerate(alge_var_eq_matching)
206+
eq isa Unassigned && continue
207+
mapping[get(dummy_sub, ts.fullvars[i], ts.fullvars[i])] = eqs[eq]
208+
end
209+
for eq in observed(sys)
210+
mapping[get(dummy_sub, eq.lhs, eq.lhs)] = eq
211+
end
212+
213+
return mapping
214+
end

0 commit comments

Comments
 (0)