Skip to content

Commit 43fa26f

Browse files
authored
Merge pull request #2175 from SciML/myb/optsys
Fix `structural_transformation` for optimization systems
2 parents e1a9a06 + ab9b772 commit 43fa26f

File tree

6 files changed

+38
-21
lines changed

6 files changed

+38
-21
lines changed

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
7373
# find them here [TODO: It would be good to have an explicit example of this.]
7474

7575
@unpack graph, solvable_graph = structure
76-
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U))
76+
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
77+
var_eq_matching = complete(var_eq_matching,
78+
max(length(var_eq_matching),
79+
maximum(x -> x isa Int ? x : 0, var_eq_matching)))
7780
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
7881
vargraph = DiCMOBiGraph{true}(graph)
7982
ict = IncrementalCycleTracker(vargraph; dir = :in)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,10 +583,10 @@ Tear the nonlinear equations in system. When `simplify=true`, we simplify the
583583
new residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
584584
instead, which calls this function internally.
585585
"""
586-
function tearing(sys::AbstractSystem; simplify = false)
587-
state = TearingState(sys)
586+
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
587+
simplify = false, kwargs...)
588588
var_eq_matching = tearing(state)
589-
invalidate_cache!(tearing_reassemble(state, var_eq_matching; simplify = simplify))
589+
invalidate_cache!(tearing_reassemble(state, var_eq_matching; mm, simplify))
590590
end
591591

592592
"""

src/systems/alias_elimination.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,10 @@ function observed2graph(eqs, states)
468468
end
469469

470470
function fixpoint_sub(x, dict)
471-
y = substitute(x, dict)
471+
y = fast_substitute(x, dict)
472472
while !isequal(x, y)
473473
y = x
474-
x = substitute(y, dict)
474+
x = fast_substitute(y, dict)
475475
end
476476

477477
return x

src/systems/optimization/optimizationsystem.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,18 +609,19 @@ function structural_simplify(sys::OptimizationSystem; kwargs...)
609609
end
610610
end
611611
nlsys = NonlinearSystem(econs, states(sys), parameters(sys); name = :___tmp_nlsystem)
612-
snlsys = structural_simplify(nlsys; check_consistency = false, kwargs...)
612+
snlsys = structural_simplify(nlsys; fully_determined = false, kwargs...)
613613
obs = observed(snlsys)
614614
subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys))
615615
seqs = equations(snlsys)
616-
cons_simplified = Array{eltype(cons), 1}(undef, length(icons) + length(seqs))
616+
cons_simplified = similar(cons, length(icons) + length(seqs))
617617
for (i, eq) in enumerate(Iterators.flatten((seqs, icons)))
618-
cons_simplified[i] = substitute(eq, subs)
618+
cons_simplified[i] = fixpoint_sub(eq, subs)
619619
end
620620
newsts = setdiff(states(sys), keys(subs))
621621
@set! sys.constraints = cons_simplified
622622
@set! sys.observed = [observed(sys); obs]
623-
@set! sys.op = substitute(equations(sys), subs)
623+
neweqs = fixpoint_sub.(equations(sys), (subs,))
624+
@set! sys.op = length(neweqs) == 1 ? first(neweqs) : neweqs
624625
@set! sys.states = newsts
625626
return sys
626627
end

src/systems/systemstructure.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,15 @@ function merge_io(io, inputs)
554554
end
555555

556556
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
557-
check_consistency = true, kwargs...)
557+
check_consistency = true, fully_determined = true,
558+
kwargs...)
558559
if state.sys isa ODESystem
559560
ci = ModelingToolkit.ClockInference(state)
560561
ModelingToolkit.infer_clocks!(ci)
561562
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
562563
cont_io = merge_io(io, inputs[continuous_id])
563564
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
564-
check_consistency,
565+
check_consistency, fully_determined,
565566
kwargs...)
566567
if length(tss) > 1
567568
# TODO: rename it to something else
@@ -576,7 +577,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
576577
end
577578
dist_io = merge_io(io, inputs[i])
578579
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
579-
kwargs...)
580+
fully_determined, kwargs...)
580581
append!(appended_parameters, inputs[i], states(ss))
581582
discrete_subsystems[i] = ss
582583
end
@@ -588,14 +589,16 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
588589
end
589590
else
590591
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
591-
kwargs...)
592+
fully_determined, kwargs...)
592593
end
593594
has_io = io !== nothing
594595
return has_io ? (sys, input_idxs) : sys
595596
end
596597

597598
function _structural_simplify!(state::TearingState, io; simplify = false,
598-
check_consistency = true, kwargs...)
599+
check_consistency = true, fully_determined = true,
600+
kwargs...)
601+
check_consistency &= fully_determined
599602
has_io = io !== nothing
600603
orig_inputs = Set()
601604
if has_io
@@ -606,7 +609,11 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
606609
if check_consistency
607610
ModelingToolkit.check_consistency(state, orig_inputs)
608611
end
609-
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm)
612+
if fully_determined
613+
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm, check_consistency)
614+
else
615+
sys = ModelingToolkit.tearing(sys, state; simplify, mm, check_consistency)
616+
end
610617
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
611618
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
612619
ModelingToolkit.invalidate_cache!(sys), input_idxs

src/utils.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -845,13 +845,19 @@ end
845845

846846
# Symbolics needs to call unwrap on the substitution rules, but most of the time
847847
# we don't want to do that in MTK.
848-
function fast_substitute(eq::Equation, subs)
849-
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
848+
const Eq = Union{Equation, Inequality}
849+
function fast_substitute(eq::Eq, subs)
850+
if eq isa Inequality
851+
Inequality(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs),
852+
eq.relational_op)
853+
else
854+
Equation(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
855+
end
850856
end
851-
function fast_substitute(eq::Equation, subs::Pair)
852-
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
857+
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
858+
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
853859
end
854-
fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,))
860+
fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,))
855861
fast_substitute(a, b) = substitute(a, b)
856862
function fast_substitute(expr, pair::Pair)
857863
a, b = pair

0 commit comments

Comments
 (0)