Skip to content

Commit 752be46

Browse files
authored
Merge pull request #1538 from SciML/myb/option
Add allow_symbolic and allow_parameter as options in structural_simplify
2 parents 076b2c5 + 2d9271d commit 752be46

File tree

5 files changed

+14
-12
lines changed

5 files changed

+14
-12
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ instead, which calls this function internally.
139139
"""
140140
function dae_index_lowering(sys::ODESystem; kwargs...)
141141
state = TearingState(sys)
142-
find_solvables!(state)
143142
var_eq_matching = pantelides!(state; kwargs...)
144143
return invalidate_cache!(pantelides_reassemble(state, var_eq_matching))
145144
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function var_derivative!(ts::TearingState{ODESystem}, v::Int)
3333
sys = ts.sys
3434
s = ts.structure
3535
D = Differential(get_iv(sys))
36-
add_vertex!(s.solvable_graph, DST)
36+
s.solvable_graph === nothing || add_vertex!(s.solvable_graph, DST)
3737
push!(ts.fullvars, D(ts.fullvars[v]))
3838
end
3939

@@ -43,7 +43,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
4343
D = Differential(get_iv(sys))
4444
eq = equations(ts)[ieq]
4545
eq = ModelingToolkit.expand_derivatives(0 ~ D(eq.rhs - eq.lhs))
46-
add_vertex!(s.solvable_graph, SRC)
46+
s.solvable_graph === nothing || add_vertex!(s.solvable_graph, SRC)
4747
push!(equations(ts), eq)
4848
# Analyze the new equation and update the graph/solvable_graph
4949
# First, copy the previous incidence and add the derivative terms.
@@ -54,7 +54,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
5454
add_edge!(s.graph, eq_diff, var)
5555
add_edge!(s.graph, eq_diff, s.var_to_diff[var])
5656
end
57-
find_eq_solvables!(ts, eq_diff; may_be_zero=true, allow_symbolic=true)
57+
s.solvable_graph === nothing || find_eq_solvables!(ts, eq_diff; may_be_zero=true, allow_symbolic=true)
5858
end
5959

6060
function tearing_sub(expr, dict, s)
@@ -242,8 +242,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify=false
242242
return invalidate_cache!(sys)
243243
end
244244

245-
function tearing(state::TearingState)
246-
state.structure.solvable_graph === nothing && find_solvables!(state)
245+
function tearing(state::TearingState; kwargs...)
246+
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
247247
complete!(state.structure)
248248
@unpack graph, solvable_graph = state.structure
249249
algvars = BitSet(findall(v->isalgvar(state.structure, v), 1:ndsts(graph)))

src/structural_transformation/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,13 @@ function find_eq_solvables!(state::TearingState, ieq; may_be_zero=false, allow_s
197197
end
198198
end
199199

200-
function find_solvables!(state::TearingState; allow_symbolic=false)
200+
function find_solvables!(state::TearingState; kwargs...)
201201
@assert state.structure.solvable_graph === nothing
202202
eqs = equations(state)
203203
graph = state.structure.graph
204204
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
205205
for ieq in 1:length(eqs)
206-
find_eq_solvables!(state, ieq; allow_symbolic)
206+
find_eq_solvables!(state, ieq; kwargs...)
207207
end
208208
return nothing
209209
end

src/systems/abstractsystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -919,9 +919,11 @@ $(SIGNATURES)
919919
920920
Structurally simplify algebraic equations in a system and compute the
921921
topological sort of the observed equations. When `simplify=true`, the `simplify`
922-
function will be applied during the tearing process.
922+
function will be applied during the tearing process. It also takes kwargs
923+
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
924+
types during tearing.
923925
"""
924-
function structural_simplify(sys::AbstractSystem; simplify=false)
926+
function structural_simplify(sys::AbstractSystem; simplify=false, kwargs...)
925927
sys = expand_connections(sys)
926928
sys = alias_elimination(sys)
927929
state = TearingState(sys)
@@ -930,7 +932,7 @@ function structural_simplify(sys::AbstractSystem; simplify=false)
930932
sys = dae_index_lowering(ode_order_lowering(sys))
931933
end
932934
state = TearingState(sys)
933-
find_solvables!(state)
935+
find_solvables!(state; kwargs...)
934936
sys = tearing_reassemble(state, tearing(state), simplify=simplify)
935937
fullstates = [map(eq->eq.lhs, observed(sys)); states(sys)]
936938
@set! sys.observed = topsort_equations(observed(sys), fullstates)

test/components.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ using ModelingToolkit.BipartiteGraphs
44
using ModelingToolkit.StructuralTransformations
55

66
function check_contract(sys)
7+
graph = ModelingToolkit.get_tearing_state(sys).structure.graph
78
sys = tearing_substitution(sys)
89
state = TearingState(sys)
910
fullvars = state.fullvars
10-
graph = state.structure.graph
1111

1212
eqs = equations(sys)
1313
var2idx = Dict(enumerate(fullvars))
@@ -30,6 +30,7 @@ end
3030

3131
include("../examples/rc_model.jl")
3232

33+
@test length(equations(structural_simplify(rc_model, allow_parameter=false))) > 1
3334
sys = structural_simplify(rc_model)
3435
check_contract(sys)
3536
@test !isempty(ModelingToolkit.defaults(sys))

0 commit comments

Comments
 (0)