Skip to content

Commit 1c578c3

Browse files
committed
rename functions
1 parent 56829f7 commit 1c578c3

File tree

3 files changed

+98
-16
lines changed

3 files changed

+98
-16
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
254254
diff_to_var = invview(var_to_diff)
255255
iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
256256

257+
@show neweqs
257258
for var in 1:length(fullvars)
259+
#@show neweqs
258260
dv = var_to_diff[var]
259261
dv === nothing && continue
260262
if var_eq_matching[var] !== SelectedState()
@@ -286,9 +288,7 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
286288
end
287289
end
288290

289-
"""
290-
Generate new derivative variables for the system.
291-
291+
#=
292292
There are three cases where we want to generate new variables to convert
293293
the system into first order (semi-implicit) ODEs.
294294
@@ -361,6 +361,16 @@ This is different from the continuous case where D(D(x)) can be substituted for
361361
by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
362362
shift_sub dict is updated at the time that the renamed variables are written,
363363
inside the loop where new variables are generated.
364+
=#
365+
"""
366+
Generate new derivative variables for the system.
367+
368+
Effects on the state:
369+
- fullvars: add the new derivative variables x_t
370+
- neweqs: add the identity equations for the new variables, D(x) ~ x_t
371+
- graph: update graph with the new equations and variables, and their connections
372+
- solvable_graph:
373+
- var_eq_matching: solvable equations
364374
"""
365375
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
366376
is_discrete = false, mm = nothing, shift_sub = nothing)
@@ -406,7 +416,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
406416
dx = fullvars[dv]
407417
order, lv = var_order(diff_to_var, dv)
408418
@show fullvars[uv]
409-
x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv])
419+
x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv], allow_zero = true)
410420
shift_sub[dx] = x_t
411421
(var_eq_matching[dv] isa Int) ? continue : @goto DISCRETE_VARIABLE
412422
end
@@ -439,7 +449,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
439449
x_t = lower_name(fullvars[lv], iv, order)
440450

441451
@label DISCRETE_VARIABLE
442-
push!(fullvars, x_t)
452+
push!(fullvars, simplify_shifts(x_t))
443453
v_t = length(fullvars)
444454
v_t_idx = add_vertex!(var_to_diff)
445455
add_vertex!(graph, DST)
@@ -448,7 +458,6 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
448458
add_vertex!(solvable_graph, DST)
449459
push!(var_eq_matching, unassigned)
450460

451-
# Add discrete substitutions to total_sub directly.
452461
is_discrete && begin
453462
idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
454463
for e in 𝑑neighbors(graph, dv)
@@ -463,7 +472,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
463472
end
464473

465474
# add `D(x) - x_t ~ 0`
466-
push!(neweqs, 0 ~ x_t - dx)
475+
push!(neweqs, 0 ~ dx - x_t)
467476
add_vertex!(graph, SRC)
468477
dummy_eq = length(neweqs)
469478
add_edge!(graph, dummy_eq, dv)
@@ -478,12 +487,11 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
478487
end
479488
end
480489

481-
function add_solvable_variable!()
490+
function add_solvable_variable!(state::TearingState)
482491

483492
end
484493

485-
function add_solvable_equation!()
486-
494+
function add_solvable_equation!(s::SystemStructure, neweqs, eq)
487495
end
488496

489497
"""
@@ -500,7 +508,8 @@ such that the mass matrix is:
500508
Update the state to account for the new ordering and equations.
501509
502510
####### DISCRETE CASE
503-
- only substitute Shift(t, -2)
511+
- Differential equations: substitute variables with everything shifted forward one timestep.
512+
- Algebraic and observable equations: substitute variables with everything shifted back one timestep.
504513
"""
505514
function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, shift_sub = Dict())
506515
@unpack fullvars, sys, structure = state
@@ -562,8 +571,6 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
562571
isnothing(D) &&
563572
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
564573
order, lv = var_order(diff_to_var, iv)
565-
@show fullvars[lv]
566-
@show simplify_shifts(fullvars[lv])
567574
dx = D(simplify_shifts(fullvars[lv]))
568575
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
569576
Symbolics.symbolic_linear_solve(neweqs[ieq],
@@ -576,8 +583,6 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
576583
push!(diff_eqs, eq)
577584
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
578585
dx_sub[simplify_shifts(eq.lhs)] = eq.rhs
579-
@show total_sub
580-
@show eq
581586
push!(diffeq_idxs, ieq)
582587
push!(diff_vars, diff_to_var[iv])
583588
continue
@@ -611,6 +616,8 @@ function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_match
611616
push!(algeeq_idxs, ieq)
612617
end
613618
end
619+
@show neweqs
620+
@show subeqs
614621

615622
# TODO: BLT sorting
616623
neweqs = [diff_eqs; alge_eqs]

src/structural_transformation/utils.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,16 @@ end
271271

272272
function find_solvables!(state::TearingState; kwargs...)
273273
@assert state.structure.solvable_graph === nothing
274+
println("in find_solvables")
275+
@show eqs
274276
eqs = equations(state)
275277
graph = state.structure.graph
276278
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
277279
to_rm = Int[]
278280
for ieq in 1:length(eqs)
279281
find_eq_solvables!(state, ieq, to_rm; kwargs...)
280282
end
283+
@show eqs
281284
return nothing
282285
end
283286

@@ -477,7 +480,7 @@ end
477480
### Rules
478481
# 1. x(t) -> x(t)
479482
# 2. Shift(t, 0)(x(t)) -> x(t)
480-
# 3. Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
483+
# 3. Shift(t, 3)(Shift(t, 2)(x(t)) -> Shift(t, 5)(x(t))
481484

482485
function simplify_shifts(var)
483486
ModelingToolkit.hasshift(var) || return var
@@ -498,3 +501,58 @@ function simplify_shifts(var)
498501
unwrap(var).metadata)
499502
end
500503
end
504+
505+
"""
506+
Power expand the shifts. Used for substitution.
507+
508+
Shift(t, -3)(x(t)) -> Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x)))
509+
"""
510+
function expand_shifts(var)
511+
ModelingToolkit.hasshift(var) || return var
512+
var = ModelingToolkit.value(var)
513+
514+
var isa Equation && return expand_shifts(var.lhs) ~ expand_shifts(var.rhs)
515+
op = operation(var)
516+
s = sign(op.steps)
517+
arg = only(arguments(var))
518+
519+
if ModelingToolkit.isvariable(arg) && (ModelingToolkit.getvariabletype(arg) === VARIABLE) && isequal(op.t, only(arguments(arg)))
520+
out = arg
521+
for i in 1:op.steps
522+
out = Shift(op.t, s)(out)
523+
end
524+
return out
525+
elseif iscall(arg)
526+
return maketerm(typeof(var), operation(var), expand_shifts.(arguments(var)),
527+
unwrap(var).metadata)
528+
else
529+
return arg
530+
end
531+
end
532+
533+
"""
534+
Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
535+
"""
536+
function distribute_shift(var)
537+
ModelingToolkit.hasshift(var) || return var
538+
var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)
539+
shift = operation(var)
540+
expr = only(arguments(var))
541+
_distribute_shift(expr, shift)
542+
end
543+
544+
function _distribute_shift(expr, shift)
545+
op = operation(expr)
546+
args = arguments(expr)
547+
548+
if length(args) == 1
549+
if ModelingToolkit.isvariable(only(args)) && isequal(op.t, only(args))
550+
return shift(only(args))
551+
else
552+
return only(args)
553+
end
554+
else iscall(op)
555+
return maketerm(typeof(expr), operation(expr), _distribute_shift.(args, shift),
556+
unwrap(var).metadata)
557+
end
558+
end

test/structural_transformation/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Graphs
44
using SparseArrays
55
using UnPack
66
using ModelingToolkit: t_nounits as t, D_nounits as D
7+
const ST = StructuralTransformations
78

89
# Define some variables
910
@parameters L g
@@ -161,3 +162,19 @@ end
161162
structural_simplify(sys; additional_passes = [pass])
162163
@test value[] == 1
163164
end
165+
166+
@testset "Shift simplification" begin
167+
@variables x(t) y(t) z(t)
168+
@parameters a b c
169+
170+
# Expand shifts
171+
@test isequal(ST.expand_shifts(Shift(t, -3)(x)), Shift(t, -1)(Shift(t, -1)(Shift(t, -1)(x))))
172+
expr = a * Shift(t, -2)(x) + Shift(t, 2)(y) + b
173+
@test isequal(ST.expand_shifts(expr),
174+
a * Shift(t, -1)(Shift(t, -1)(x)) + Shift(t, 1)(Shift(t, 1)(y)) + b)
175+
@test isequal(ST.expand_shifts(Shift(t, 2)(Shift(t, 1)(a))), a)
176+
177+
178+
# Distribute shifts
179+
180+
end

0 commit comments

Comments
 (0)