Skip to content

Commit 4c31ede

Browse files
committed
refactor simplification functions
to avoid code duplication
1 parent bb0fa0a commit 4c31ede

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

src/systems/abstractsystem.jl

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -948,10 +948,17 @@ function will be applied during the tearing process. It also takes kwargs
948948
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
949949
types during tearing.
950950
"""
951-
function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
951+
function structural_simplify(sys::AbstractSystem, args...; kwargs...)
952952
sys = expand_connections(sys)
953953
state = TearingState(sys)
954-
state, = inputs_to_parameters!(state)
954+
sys, input_idxs = _structural_simplify(sys, state, args...; kwargs...)
955+
sys
956+
end
957+
958+
function _structural_simplify(sys::AbstractSystem, state; simplify = false,
959+
check_bound = true,
960+
kwargs...)
961+
state, input_idxs = inputs_to_parameters!(state, check_bound)
955962
sys = alias_elimination!(state)
956963
state = TearingState(sys)
957964
check_consistency(state)
@@ -964,7 +971,31 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
964971
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
965972
@set! sys.observed = topsort_equations(observed(sys), fullstates)
966973
invalidate_cache!(sys)
967-
return sys
974+
return sys, input_idxs
975+
end
976+
977+
function io_preprocessing(sys::AbstractSystem, inputs,
978+
outputs; simplify = false, kwargs...)
979+
sys = expand_connections(sys)
980+
state = TearingState(sys)
981+
markio!(state, inputs, outputs)
982+
sys, input_idxs = _structural_simplify(sys, state; simplify, check_bound = false,
983+
kwargs...)
984+
985+
eqs = equations(sys)
986+
check_operator_variables(eqs, Differential)
987+
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
988+
diffstates = collect_operator_variables(sys, Differential)
989+
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
990+
alg = Base.Sort.DEFAULT_STABLE)
991+
@set! sys.eqs = eqs
992+
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
993+
sts = [diffstates; setdiff(states(sys), diffstates)]
994+
@set! sys.states = sts
995+
diff_idxs = 1:length(diffstates)
996+
alge_idxs = (length(diffstates) + 1):length(sts)
997+
998+
sys, diff_idxs, alge_idxs, input_idxs
968999
end
9691000

9701001
"""
@@ -994,36 +1025,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
9941025
function linearization_function(sys::AbstractSystem, inputs,
9951026
outputs; simplify = false,
9961027
kwargs...)
997-
sys = expand_connections(sys)
998-
state = TearingState(sys)
999-
markio!(state, inputs, outputs)
1000-
state, input_idxs = inputs_to_parameters!(state, false)
1001-
sys = alias_elimination!(state)
1002-
state = TearingState(sys)
1003-
check_consistency(state)
1004-
if sys isa ODESystem
1005-
sys = dae_order_lowering(dummy_derivative(sys, state))
1006-
end
1007-
state = TearingState(sys)
1008-
find_solvables!(state; kwargs...)
1009-
sys = tearing_reassemble(state, tearing(state), simplify = simplify)
1010-
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
1011-
@set! sys.observed = topsort_equations(observed(sys), fullstates)
1012-
invalidate_cache!(sys)
1013-
1014-
eqs = equations(sys)
1015-
check_operator_variables(eqs, Differential)
1016-
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
1017-
diffstates = collect_operator_variables(sys, Differential)
1018-
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
1019-
alg = Base.Sort.DEFAULT_STABLE)
1020-
@set! sys.eqs = eqs
1021-
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
1022-
sts = [diffstates; setdiff(states(sys), diffstates)]
1023-
@set! sys.states = sts
1024-
1025-
diff_idxs = 1:length(diffstates)
1026-
alge_idxs = (length(diffstates) + 1):length(sts)
1028+
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
1029+
kwargs...)
1030+
sts = states(sys)
10271031
fun = ODEFunction(sys)
10281032
lin_fun = let fun = fun,
10291033
h = ModelingToolkit.build_explicit_observed_function(sys, outputs)

0 commit comments

Comments
 (0)