Skip to content

Commit f6452d6

Browse files
committed
Add allow_symbolic and allow_parameter as options in structural_simplify
1 parent 076b2c5 commit f6452d6

File tree

5 files changed

+11
-9
lines changed

5 files changed

+11
-9
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ instead, which calls this function internally.
139139
"""
140140
function dae_index_lowering(sys::ODESystem; kwargs...)
141141
state = TearingState(sys)
142-
find_solvables!(state)
143142
var_eq_matching = pantelides!(state; kwargs...)
144143
return invalidate_cache!(pantelides_reassemble(state, var_eq_matching))
145144
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify=false
242242
return invalidate_cache!(sys)
243243
end
244244

245-
function tearing(state::TearingState)
246-
state.structure.solvable_graph === nothing && find_solvables!(state)
245+
function tearing(state::TearingState; kwargs...)
246+
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
247247
complete!(state.structure)
248248
@unpack graph, solvable_graph = state.structure
249249
algvars = BitSet(findall(v->isalgvar(state.structure, v), 1:ndsts(graph)))

src/structural_transformation/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,13 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero=false, allow_s
197197
end
198198
end
199199

200-
function find_solvables!(state::TearingState; allow_symbolic=false)
200+
function find_solvables!(state::TearingState; kwargs...)
201201
@assert state.structure.solvable_graph === nothing
202202
eqs = equations(state)
203203
graph = state.structure.graph
204204
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
205205
for ieq in 1:length(eqs)
206-
find_eq_solvables!(state, ieq; allow_symbolic)
206+
find_eq_solvables!(state, ieq; kwargs...)
207207
end
208208
return nothing
209209
end

src/systems/abstractsystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -919,9 +919,11 @@ $(SIGNATURES)
919919
920920
Structurally simplify algebraic equations in a system and compute the
921921
topological sort of the observed equations. When `simplify=true`, the `simplify`
922-
function will be applied during the tearing process.
922+
function will be applied during the tearing process. It also takes kwargs
923+
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
924+
types during tearing.
923925
"""
924-
function structural_simplify(sys::AbstractSystem; simplify=false)
926+
function structural_simplify(sys::AbstractSystem; simplify=false, kwargs...)
925927
sys = expand_connections(sys)
926928
sys = alias_elimination(sys)
927929
state = TearingState(sys)
@@ -930,7 +932,7 @@ function structural_simplify(sys::AbstractSystem; simplify=false)
930932
sys = dae_index_lowering(ode_order_lowering(sys))
931933
end
932934
state = TearingState(sys)
933-
find_solvables!(state)
935+
find_solvables!(state; kwargs...)
934936
sys = tearing_reassemble(state, tearing(state), simplify=simplify)
935937
fullstates = [map(eq->eq.lhs, observed(sys)); states(sys)]
936938
@set! sys.observed = topsort_equations(observed(sys), fullstates)

test/components.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ using ModelingToolkit.BipartiteGraphs
44
using ModelingToolkit.StructuralTransformations
55

66
function check_contract(sys)
7+
graph = ModelingToolkit.get_tearing_state(sys).structure.graph
78
sys = tearing_substitution(sys)
89
state = TearingState(sys)
910
fullvars = state.fullvars
10-
graph = state.structure.graph
1111

1212
eqs = equations(sys)
1313
var2idx = Dict(enumerate(fullvars))
@@ -30,6 +30,7 @@ end
3030

3131
include("../examples/rc_model.jl")
3232

33+
@test length(equations(structural_simplify(rc_model, allow_parameter=false))) > 1
3334
sys = structural_simplify(rc_model)
3435
check_contract(sys)
3536
@test !isempty(ModelingToolkit.defaults(sys))

0 commit comments

Comments
 (0)