|
| 1 | +using SparseArrays |
| 2 | + |
| 3 | +const SHOW_EQUATIONS = Ref(false) |
| 4 | +struct SystemStructure |
| 5 | + sys |
| 6 | + dxvar_offset::Int |
| 7 | + fullvars::Vector # [xvar; dxvars; algvars] |
| 8 | + varassoc::Vector{Int} |
| 9 | + graph::BipartiteGraph{Int} |
| 10 | + solvable_graph::BipartiteGraph{Int} |
| 11 | +end |
| 12 | +function SystemStructure(sys) |
| 13 | + sys = ModelingToolkit.flatten(sys) |
| 14 | + sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph = init_graph(sys) |
| 15 | + SystemStructure(sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph) |
| 16 | +end |
| 17 | + |
| 18 | +ModelingToolkit.equations(s::SystemStructure) = equations(s.sys) |
| 19 | + |
| 20 | +function Base.show(io::IO, s::SystemStructure) |
| 21 | + @unpack fullvars, dxvar_offset, solvable_graph, graph = s |
| 22 | + algvar_offset = 2dxvar_offset |
| 23 | + print(io, "xvars: ") |
| 24 | + print(io, fullvars[1:dxvar_offset]) |
| 25 | + print(io, "\ndxvars: ") |
| 26 | + print(io, fullvars[dxvar_offset+1:algvar_offset]) |
| 27 | + print(io, "\nalgvars: ") |
| 28 | + print(io, fullvars[algvar_offset+1:end], '\n') |
| 29 | + |
| 30 | + if SHOW_EQUATIONS[] |
| 31 | + println(io, "Edges:") |
| 32 | + eqs = equations(s) |
| 33 | + for ev in 𝑠vertices(graph) |
| 34 | + print(io, " $(eqs[ev])\n -> ") |
| 35 | + vars = 𝑠neighbors(graph, ev) |
| 36 | + solvars = 𝑠neighbors(solvable_graph, ev) |
| 37 | + solvable = intersect(vars, solvars) |
| 38 | + notsolvable = setdiff(vars, solvars) |
| 39 | + |
| 40 | + print(io, join(string.(fullvars[notsolvable]), ", ")) |
| 41 | + for ii in solvable |
| 42 | + print(io, ", ") |
| 43 | + var = fullvars[ii] |
| 44 | + Base.printstyled(io, string(fullvars[ii]), color=:cyan) |
| 45 | + end |
| 46 | + println(io) |
| 47 | + end |
| 48 | + end |
| 49 | + |
| 50 | + S = incidence_matrix(graph, Num(Sym{Real}(:×))) |
| 51 | + print(io, "Incidence matrix:") |
| 52 | + show(io, S) |
| 53 | +end |
| 54 | + |
| 55 | +# V-nodes `[x_1, x_2, x_3, ..., dx_1, dx_2, ..., y_1, y_2, ...]` where `x`s are |
| 56 | +# differential variables and `y`s are algebraic variables. |
| 57 | +function get_vnodes(sys) |
| 58 | + dxvars = [] |
| 59 | + eqs = equations(sys) |
| 60 | + for (i, eq) in enumerate(eqs) |
| 61 | + if eq.lhs isa Symbolic |
| 62 | + # Make sure that the LHS is a first order derivative of a var. |
| 63 | + @assert operation(eq.lhs) isa Differential "The equation $eq is not in the form of `D(...) ~ ...`" |
| 64 | + @assert !(arguments(eq.lhs)[1] isa Differential) "The equation $eq is not first order" |
| 65 | + |
| 66 | + push!(dxvars, eq.lhs) |
| 67 | + end |
| 68 | + end |
| 69 | + |
| 70 | + xvars = (first ∘ var_from_nested_derivative).(dxvars) |
| 71 | + algvars = setdiff(states(sys), xvars) |
| 72 | + return xvars, dxvars, algvars |
| 73 | +end |
| 74 | + |
| 75 | +function init_graph(sys) |
| 76 | + xvars, dxvars, algvars = get_vnodes(sys) |
| 77 | + dxvar_offset = length(xvars) |
| 78 | + algvar_offset = 2dxvar_offset |
| 79 | + |
| 80 | + fullvars = [xvars; dxvars; algvars] |
| 81 | + sys = reordersys(sys, dxvar_offset, fullvars) |
| 82 | + eqs = equations(sys) |
| 83 | + idxmap = Dict(fullvars .=> 1:length(fullvars)) |
| 84 | + graph = BipartiteGraph(length(eqs), length(fullvars)) |
| 85 | + solvable_graph = BipartiteGraph(length(eqs), length(fullvars)) |
| 86 | + |
| 87 | + for (i, eq) in enumerate(eqs) |
| 88 | + if isdiffeq(eq) |
| 89 | + v = eq.lhs |
| 90 | + haskey(idxmap, v) && add_edge!(graph, i, idxmap[v]) |
| 91 | + end |
| 92 | + # TODO: custom vars that handles D(x) |
| 93 | + vs = vars(eq.rhs) |
| 94 | + for v in vs |
| 95 | + haskey(idxmap, v) && add_edge!(graph, i, idxmap[v]) |
| 96 | + end |
| 97 | + end |
| 98 | + |
| 99 | + varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list |
| 100 | + sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph |
| 101 | +end |
| 102 | + |
| 103 | +function reordersys(sys, dxvar_offset, fullvars) |
| 104 | + eqs = equations(sys) |
| 105 | + neweqs = Vector{Equation}(undef, length(eqs)) |
| 106 | + eqidxmap = Dict(@view(fullvars[dxvar_offset+1:2dxvar_offset]) .=> (1:dxvar_offset)) |
| 107 | + varidxmap = Dict([@view(fullvars[1:dxvar_offset]); @view(fullvars[2dxvar_offset+1:end])] .=> (1:length(fullvars)-dxvar_offset)) |
| 108 | + algidx = dxvar_offset |
| 109 | + for eq in eqs |
| 110 | + if isdiffeq(eq) |
| 111 | + neweqs[eqidxmap[eq.lhs]] = eq |
| 112 | + else |
| 113 | + neweqs[algidx+=1] = eq |
| 114 | + end |
| 115 | + end |
| 116 | + sts = states(sys) |
| 117 | + @set! sys.eqs = neweqs |
| 118 | + @set! sys.states = sts[map(s->varidxmap[s], sts)] |
| 119 | +end |
0 commit comments