Skip to content

Commit f959c57

Browse files
committed
Merge branch 'master' into myb/strong_alias
2 parents 7c026b0 + 644d76c commit f959c57

File tree

6 files changed

+151
-14
lines changed

6 files changed

+151
-14
lines changed

src/bipartite_graph.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,49 @@ end
191191
# Matrix whose only purpose is to pretty-print the bipartite graph
192192
struct BipartiteAdjacencyList
193193
u::Union{Vector{Int}, Nothing}
194+
highligh_u::Union{Set{Int}, Nothing}
195+
match::Union{Int, Unassigned}
194196
end
197+
function BipartiteAdjacencyList(u::Union{Vector{Int}, Nothing})
198+
BipartiteAdjacencyList(u, nothing, unassigned)
199+
end
200+
201+
struct HighlightInt
202+
i::Int
203+
highlight::Union{Symbol, Nothing}
204+
end
205+
Base.typeinfo_implicit(::Type{HighlightInt}) = true
206+
207+
function Base.show(io::IO, hi::HighlightInt)
208+
if hi.highlight !== nothing
209+
printstyled(io, hi.i, color = hi.highlight)
210+
else
211+
print(io, hi.i)
212+
end
213+
end
214+
195215
function Base.show(io::IO, l::BipartiteAdjacencyList)
196216
if l.u === nothing
197217
printstyled(io, '', color = :light_black)
198218
elseif isempty(l.u)
199219
printstyled(io, '', color = :light_black)
200-
else
220+
elseif l.highligh_u === nothing
201221
print(io, l.u)
222+
else
223+
function choose_color(i)
224+
i in l.highligh_u ? (i == l.match ? :light_yellow : :green) :
225+
(i == l.match ? :yellow : nothing)
226+
end
227+
if !isempty(setdiff(l.highligh_u, l.u))
228+
# Only for debugging, shouldn't happen in practice
229+
print(io, map(union(l.u, l.highligh_u)) do i
230+
HighlightInt(i, !(i in l.u) ? :light_red : choose_color(i))
231+
end)
232+
else
233+
print(io, map(l.u) do i
234+
HighlightInt(i, choose_color(i))
235+
end)
236+
end
202237
end
203238
end
204239

src/inputoutput.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
218218
f, dvs, ps
219219
end
220220

221-
function inputs_to_parameters!(state::TransformationState, check_bound = true)
221+
function inputs_to_parameters!(state::TransformationState, io)
222+
check_bound = io === nothing
222223
@unpack structure, fullvars, sys = state
223224
@unpack var_to_diff, graph, solvable_graph = structure
224225
@assert solvable_graph === nothing
@@ -274,6 +275,17 @@ function inputs_to_parameters!(state::TransformationState, check_bound = true)
274275
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
275276
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
276277
ps = parameters(sys)
278+
279+
if io !== nothing
280+
# Change order of new parameters to correspond to user-provided order in argument `inputs`
281+
d = Dict{Any, Int}()
282+
for (i, inp) in enumerate(new_parameters)
283+
d[inp] = i
284+
end
285+
permutation = [d[i] for i in io.inputs]
286+
new_parameters = new_parameters[permutation]
287+
end
288+
277289
@set! sys.ps = [ps; new_parameters]
278290

279291
@set! state.sys = sys

src/systems/abstractsystem.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
968968
state = TearingState(sys)
969969
has_io = io !== nothing
970970
has_io && markio!(state, io...)
971-
state, input_idxs = inputs_to_parameters!(state, !has_io)
971+
state, input_idxs = inputs_to_parameters!(state, io)
972972
sys = alias_elimination!(state)
973973
# TODO: avoid construct `TearingState` again.
974974
state = TearingState(sys)
@@ -984,7 +984,7 @@ end
984984

985985
function io_preprocessing(sys::AbstractSystem, inputs,
986986
outputs; simplify = false, kwargs...)
987-
sys, input_idxs = structural_simplify(sys, (inputs, outputs); simplify, kwargs...)
987+
sys, input_idxs = structural_simplify(sys, (; inputs, outputs); simplify, kwargs...)
988988

989989
eqs = equations(sys)
990990
alg_start_idx = findfirst(!isdiffeq, eqs)
@@ -1201,6 +1201,8 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
12011201
nz = size(f_z, 2)
12021202
ny = size(h_x, 1)
12031203

1204+
D = h_u
1205+
12041206
if isempty(g_z)
12051207
A = f_x
12061208
B = f_u
@@ -1216,20 +1218,20 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
12161218
A = [f_x f_z
12171219
gzgx*f_x gzgx*f_z]
12181220
B = [f_u
1219-
zeros(nz, nu)]
1221+
gzgx * f_u] # The cited paper has zeros in the bottom block, see derivation in https://github.com/SciML/ModelingToolkit.jl/pull/1691 for the correct formula
1222+
12201223
C = [h_x h_z]
12211224
Bs = -(gz \ g_u) # This equation differ from the cited paper, the paper is likely wrong since their equaiton leads to a dimension mismatch.
12221225
if !iszero(Bs)
12231226
if !allow_input_derivatives
12241227
der_inds = findall(vec(any(!=(0), Bs, dims = 1)))
1225-
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(inputs(sys)[der_inds]). Call `linear_staespace` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
1228+
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(inputs(sys)[der_inds]). Call `linear_statespace` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
12261229
end
1227-
B = [B Bs]
1230+
B = [B [zeros(nx, nu); Bs]]
1231+
D = [D zeros(ny, nu)]
12281232
end
12291233
end
12301234

1231-
D = h_u
1232-
12331235
(; A, B, C, D)
12341236
end
12351237

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ function build_explicit_observed_function(sys, ts;
316316
subs[s] = s′
317317
continue
318318
end
319-
throw(ArgumentError("$s is either an observed nor a state variable."))
319+
throw(ArgumentError("$s is neither an observed nor a state variable."))
320320
end
321321
continue
322322
end

src/systems/systemstructure.jl

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,93 @@ function linear_subsys_adjmat(state::TransformationState)
403403
linear_equations, eadj, cadj)
404404
end
405405

406+
using .BipartiteGraphs: Label, BipartiteAdjacencyList
407+
struct SystemStructurePrintMatrix <:
408+
AbstractMatrix{Union{Label, Int, BipartiteAdjacencyList}}
409+
bpg::BipartiteGraph
410+
highlight_graph::BipartiteGraph
411+
var_to_diff::DiffGraph
412+
eq_to_diff::DiffGraph
413+
var_eq_matching::Union{Matching, Nothing}
414+
end
415+
Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 5)
416+
function compute_diff_label(diff_graph, i)
417+
di = i - 1 <= length(diff_graph) ? diff_graph[i - 1] : nothing
418+
ii = i - 1 <= length(invview(diff_graph)) ? invview(diff_graph)[i - 1] : nothing
419+
return Label(string(di === nothing ? "" : string(di, ''),
420+
di !== nothing && ii !== nothing ? " " : "",
421+
ii === nothing ? "" : string(ii, '')))
422+
end
423+
function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer)
424+
checkbounds(bgpm, i, j)
425+
if i <= 1
426+
return (Label.(("#", "∂ₜ", "eq", "∂ₜ", "v")))[j]
427+
elseif j == 2
428+
return compute_diff_label(bgpm.eq_to_diff, i)
429+
elseif j == 4
430+
return compute_diff_label(bgpm.var_to_diff, i)
431+
elseif j == 1
432+
return i - 1
433+
elseif j == 3
434+
return BipartiteAdjacencyList(i - 1 <= nsrcs(bgpm.bpg) ?
435+
𝑠neighbors(bgpm.bpg, i - 1) : nothing,
436+
bgpm.highlight_graph !== nothing &&
437+
i - 1 <= nsrcs(bgpm.highlight_graph) ?
438+
Set(𝑠neighbors(bgpm.highlight_graph, i - 1)) :
439+
nothing,
440+
bgpm.var_eq_matching !== nothing &&
441+
(i - 1 <= length(invview(bgpm.var_eq_matching))) ?
442+
invview(bgpm.var_eq_matching)[i - 1] : unassigned)
443+
elseif j == 5
444+
return BipartiteAdjacencyList(i - 1 <= ndsts(bgpm.bpg) ?
445+
𝑑neighbors(bgpm.bpg, i - 1) : nothing,
446+
bgpm.highlight_graph !== nothing &&
447+
i - 1 <= ndsts(bgpm.highlight_graph) ?
448+
Set(𝑑neighbors(bgpm.highlight_graph, i - 1)) :
449+
nothing,
450+
bgpm.var_eq_matching !== nothing &&
451+
(i - 1 <= length(bgpm.var_eq_matching)) ?
452+
bgpm.var_eq_matching[i - 1] : unassigned)
453+
else
454+
@assert false
455+
end
456+
end
457+
406458
function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure)
407-
@unpack graph = s
408-
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
409-
print(io, "Incidence matrix:")
410-
show(io, mime, S)
459+
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = s
460+
if !get(io, :limit, true) || !get(io, :mtk_limit, true)
461+
print(io, "SystemStructure with ", length(graph.fadjlist), " equations and ",
462+
isa(graph.badjlist, Int) ? graph.badjlist : length(graph.badjlist),
463+
" variables\n")
464+
Base.print_matrix(io,
465+
SystemStructurePrintMatrix(complete(graph),
466+
complete(solvable_graph),
467+
complete(var_to_diff),
468+
complete(eq_to_diff), nothing))
469+
else
470+
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
471+
print(io, "Incidence matrix:")
472+
show(io, mime, S)
473+
end
474+
end
475+
476+
struct MatchedSystemStructure
477+
structure::SystemStructure
478+
var_eq_matching::Matching
479+
end
480+
481+
function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
482+
s = ms.structure
483+
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = s
484+
print(io, "Matched SystemStructure with ", length(graph.fadjlist), " equations and ",
485+
isa(graph.badjlist, Int) ? graph.badjlist : length(graph.badjlist),
486+
" variables\n")
487+
Base.print_matrix(io,
488+
SystemStructurePrintMatrix(complete(graph),
489+
complete(solvable_graph),
490+
complete(var_to_diff),
491+
complete(eq_to_diff),
492+
complete(ms.var_eq_matching, nsrcs(graph))))
411493
end
412494

413495
end # module

test/structural_transformation/tearing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ if VERSION >= v"1.6"
3030
@test occursin("Incidence matrix:", prt)
3131
@test occursin("×", prt)
3232
@test occursin("", prt)
33+
34+
buff = IOBuffer()
35+
io = IOContext(buff, :mtk_limit => false)
36+
show(io, MIME"text/plain"(), state.structure)
37+
prt = String(take!(buff))
38+
@test occursin("SystemStructure", prt)
3339
end
3440

3541
# u1 = f1(u5)

0 commit comments

Comments
 (0)