Skip to content

Commit b530595

Browse files
committed
working simplification
1 parent f5e5aff commit b530595

File tree

3 files changed

+88
-18
lines changed

3 files changed

+88
-18
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ end
240240
function tearing_reassemble(state::TearingState, var_eq_matching,
241241
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
242242
@unpack fullvars, sys, structure = state
243-
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
243+
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
244244
extra_vars = Int[]
245245
if full_var_eq_matching !== nothing
246246
for v in 𝑑vertices(state.structure.graph)
@@ -279,6 +279,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
279279
iv = D = nothing
280280
end
281281
diff_to_var = invview(var_to_diff)
282+
282283
dummy_sub = Dict()
283284
for var in 1:length(fullvars)
284285
dv = var_to_diff[var]
@@ -310,7 +311,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
310311
diff_to_var[dv] = nothing
311312
end
312313
end
314+
@show neweqs
313315

316+
println("Post state selection.")
317+
314318
# `SelectedState` information is no longer needed past here. State selection
315319
# is done. All non-differentiated variables are algebraic variables, and all
316320
# variables that appear differentiated are differential variables.
@@ -331,10 +335,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
331335
order += 1
332336
dv = dv′
333337
end
338+
println("Order")
339+
@show fullvars[dv]
340+
is_only_discrete(state.structure) && begin
341+
var = fullvars[dv]
342+
key = operation(var) isa Shift ? only(arguments(var)) : var
343+
order = -get(lowest_shift, key, 0) - order
344+
end
334345
order, dv
335346
end
336347
end
337348

349+
lower_name = is_only_discrete(state.structure) ? lower_varname_withshift : lower_varname_with_unit
350+
# is_only_discrete(state.structure) && for v in 1:length(fullvars)
351+
# var = fullvars[v]
352+
# op = operation(var)
353+
# if op isa Shift
354+
# x = only(arguments(var))
355+
# lowest_shift_idxs[v]
356+
# op.steps == lowest_shift[x] && (fullvars[v] = lower_varname_withshift(var, iv, -op.steps))
357+
# end
358+
# end
359+
338360
#retear = BitSet()
339361
# There are three cases where we want to generate new variables to convert
340362
# the system into first order (semi-implicit) ODEs.
@@ -384,9 +406,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
384406
eq_var_matching = invview(var_eq_matching)
385407
linear_eqs = mm === nothing ? Dict{Int, Int}() :
386408
Dict(reverse(en) for en in enumerate(mm.nzrows))
409+
387410
for v in 1:length(var_to_diff)
388-
dv = var_to_diff[v]
411+
println()
412+
@show fullvars
413+
@show diff_to_var
414+
is_highest_discrete = begin
415+
var = fullvars[v]
416+
op = operation(var)
417+
if (!is_only_discrete(state.structure) || op isa Shift)
418+
false
419+
elseif !haskey(lowest_shift, var)
420+
false
421+
else
422+
low = lowest_shift[var]
423+
idx = findfirst(x -> isequal(x, Shift(iv, low)(var)), fullvars)
424+
true
425+
end
426+
end
427+
dv = is_highest_discrete ? idx : var_to_diff[v]
428+
@show (v, fullvars[v], dv)
389429
dv isa Int || continue
430+
390431
solved = var_eq_matching[dv] isa Int
391432
solved && continue
392433
# check if there's `D(x) = x_t` already
@@ -404,17 +445,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
404445
diff_to_var[v_t] === nothing)
405446
@assert dv in rvs
406447
dummy_eq = eq
448+
@show "FOUND DUMMY EQ"
407449
@goto FOUND_DUMMY_EQ
408450
end
409451
end
410452
dx = fullvars[dv]
411453
# add `x_t`
412-
order, lv = var_order(dv)
413-
x_t = lower_varname_withshift(fullvars[lv], iv, order)
454+
@show order, lv = var_order(dv)
455+
x_t = lower_name(fullvars[lv], iv, order)
414456
push!(fullvars, simplify_shifts(x_t))
415457
v_t = length(fullvars)
416458
v_t_idx = add_vertex!(var_to_diff)
417459
add_vertex!(graph, DST)
460+
@show x_t, dx
418461
# TODO: do we care about solvable_graph? We don't use them after
419462
# `dummy_derivative_graph`.
420463
add_vertex!(solvable_graph, DST)
@@ -433,10 +476,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
433476
add_edge!(solvable_graph, dummy_eq, dv)
434477
@assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
435478
@label FOUND_DUMMY_EQ
479+
@show is_highest_discrete
480+
@show diff_to_var
481+
@show v_t, dv
482+
# If var = x with no shift, then
483+
is_highest_discrete && (lowest_shift[x_t] = lowest_shift[fullvars[v]])
436484
var_to_diff[v_t] = var_to_diff[dv]
437485
var_eq_matching[dv] = unassigned
438486
eq_var_matching[dummy_eq] = dv
439487
end
488+
@show neweqs
440489

441490
# Will reorder equations and unknowns to be:
442491
# [diffeqs; ...]
@@ -537,6 +586,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
537586

538587
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
539588
for i in 1:length(solved_equations)]
589+
540590
# Contract the vertices in the structure graph to make the structure match
541591
# the new reality of the system we've just created.
542592
graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,

src/structural_transformation/utils.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -451,18 +451,22 @@ end
451451

452452
function lower_varname_withshift(var, iv, order)
453453
order == 0 && return var
454+
ds = "$iv-$order"
455+
d_separator = 'ˍ'
456+
454457
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
455458
O = only(arguments(var))
456459
oldop = operation(O)
457-
ds = "$iv-$order"
458-
d_separator = 'ˍ'
459460
newname = Symbol(string(nameof(oldop)), d_separator, ds)
460-
461-
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
462-
setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
463-
return ModelingToolkit._with_unit(identity, newvar, iv)
461+
else
462+
O = var
463+
oldop = operation(var)
464+
varname = split(string(nameof(oldop)), d_separator)[1]
465+
newname = Symbol(varname, d_separator, ds)
464466
end
465-
return lower_varname_with_unit(var, iv, order)
467+
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
468+
setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
469+
return ModelingToolkit._with_unit(identity, newvar, iv)
466470
end
467471

468472
function isdoubleshift(var)

src/systems/systemstructure.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,21 @@ get_fullvars(ts::TransformationState) = ts.fullvars
140140
has_equations(::TransformationState) = true
141141

142142
Base.@kwdef mutable struct SystemStructure
143-
# Maps the (index of) a variable to the (index of) the variable describing
144-
# its derivative.
143+
"""Maps the (index of) a variable to the (index of) the variable describing its derivative."""
145144
var_to_diff::DiffGraph
145+
"""Maps the (index of) a """
146146
eq_to_diff::DiffGraph
147147
# Can be access as
148148
# `graph` to automatically look at the bipartite graph
149149
# or as `torn` to assert that tearing has run.
150+
"""Incidence graph of the system of equations. An edge from equation x to variable y exists if variable y appears in equation x."""
150151
graph::BipartiteGraph{Int, Nothing}
152+
"""."""
151153
solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
152154
var_types::Union{Vector{VariableType}, Nothing}
155+
"""Whether the system is discrete."""
153156
only_discrete::Bool
157+
lowest_shift::Union{Dict, Nothing}
154158
end
155159

156160
function Base.copy(structure::SystemStructure)
@@ -346,6 +350,8 @@ function TearingState(sys; quick_cancel = false, check = true)
346350
eqs[i] = eqs[i].lhs ~ rhs
347351
end
348352
end
353+
354+
### Handle discrete variables
349355
lowest_shift = Dict()
350356
for var in fullvars
351357
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
@@ -430,10 +436,10 @@ function TearingState(sys; quick_cancel = false, check = true)
430436

431437
ts = TearingState(sys, fullvars,
432438
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
433-
complete(graph), nothing, var_types, sys isa DiscreteSystem),
439+
complete(graph), nothing, var_types, sys isa DiscreteSystem, lowest_shift),
434440
Any[])
435441
if sys isa DiscreteSystem
436-
ts = shift_discrete_system(ts)
442+
ts = shift_discrete_system(ts, lowest_shift)
437443
end
438444
return ts
439445
end
@@ -456,17 +462,27 @@ function lower_order_var(dervar, t)
456462
diffvar
457463
end
458464

459-
function shift_discrete_system(ts::TearingState)
465+
"""
466+
Shift variable x by the largest shift s such that x(k-s) appears in the system of equations.
467+
The lowest-shift term will have.
468+
"""
469+
function shift_discrete_system(ts::TearingState, lowest_shift)
460470
@unpack fullvars, sys = ts
471+
return ts
461472
discvars = OrderedSet()
462473
eqs = equations(sys)
474+
463475
for eq in eqs
464476
vars!(discvars, eq; op = Union{Sample, Hold})
465477
end
466478
iv = get_iv(sys)
467-
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
479+
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, -get(lowest_shift, k, 0))(k))
468480
for k in discvars
469-
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
481+
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
482+
483+
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) for k in discvars
484+
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
485+
470486
for i in eachindex(fullvars)
471487
fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(
472488
fullvars[i], discmap; operator = Union{Sample, Hold}))

0 commit comments

Comments
 (0)