Skip to content

Commit e583e0a

Browse files
feat: reduce reliance on metadata in structural_simplify
1 parent 7f37457 commit e583e0a

File tree

2 files changed

+143
-85
lines changed

2 files changed

+143
-85
lines changed

src/inputoutput.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
319319

320320
@set! sys.ps = [ps; new_parameters]
321321
@set! state.sys = sys
322-
@set! state.fullvars = new_fullvars
322+
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
323323
@set! state.structure = structure
324324
return state
325325
end

src/systems/systemstructure.jl

Lines changed: 142 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204
"""The system of equations."""
205205
sys::T
206206
"""The set of variables of the system."""
207-
fullvars::Vector
207+
fullvars::Vector{BasicSymbolic}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
@@ -254,128 +254,164 @@ function Base.push!(ev::EquationsView, eq)
254254
push!(ev.ts.extra_eqs, eq)
255255
end
256256

257-
function is_time_dependent_parameter(p, iv)
258-
return iv !== nothing && isparameter(p) && iscall(p) &&
259-
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
257+
function is_time_dependent_parameter(p, allps, iv)
258+
return iv !== nothing && p in allps && iscall(p) &&
259+
(operation(p) === getindex &&
260+
is_time_dependent_parameter(arguments(p)[1], allps, iv) ||
260261
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
261262
end
262263

264+
function symbolic_contains(var, set)
265+
var in set ||
266+
symbolic_type(var) == ArraySymbolic() &&
267+
Symbolics.shape(var) != Symbolics.Unknown() &&
268+
all(x -> x in set, Symbolics.scalarize(var))
269+
end
270+
263271
function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
272+
# flatten system
264273
sys = flatten(sys)
265274
ivs = independent_variables(sys)
266275
iv = length(ivs) == 1 ? ivs[1] : nothing
267-
# scalarize array equations, without scalarizing arguments to registered functions
268-
eqs = flatten_equations(copy(equations(sys)))
276+
# flatten array equations
277+
eqs = flatten_equations(equations(sys))
269278
neqs = length(eqs)
270-
dervaridxs = OrderedSet{Int}()
271-
var2idx = Dict{Any, Int}()
272-
symbolic_incidence = []
273-
fullvars = []
274279
param_derivative_map = Dict{BasicSymbolic, Any}()
275-
var_counter = Ref(0)
280+
# * Scalarize unknowns
281+
dvs = Set{BasicSymbolic}()
282+
fullvars = BasicSymbolic[]
283+
for x in unknowns(sys)
284+
push!(dvs, x)
285+
xx = Symbolics.scalarize(x)
286+
if xx isa AbstractArray
287+
union!(dvs, xx)
288+
end
289+
end
290+
ps = Set{Symbolic}()
291+
for x in full_parameters(sys)
292+
push!(ps, x)
293+
if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown()
294+
xx = Symbolics.scalarize(x)
295+
union!(ps, xx)
296+
end
297+
end
298+
browns = Set{BasicSymbolic}()
299+
for x in brownians(sys)
300+
push!(browns, x)
301+
xx = Symbolics.scalarize(x)
302+
if xx isa AbstractArray
303+
union!(browns, xx)
304+
end
305+
end
306+
var2idx = Dict{BasicSymbolic, Int}()
276307
var_types = VariableType[]
277-
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
278-
var -> get!(var2idx, var) do
308+
addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx, var_types = var_types
309+
(var, vtype) -> get!(var2idx, var) do
310+
push!(dvs, var)
279311
push!(fullvars, var)
280-
push!(var_types, getvariabletype(var))
281-
var_counter[] += 1
312+
push!(var_types, vtype)
313+
return length(fullvars)
282314
end
283315
end
284316

285-
vars = OrderedSet()
286-
varsvec = []
317+
# build symbolic incidence
318+
symbolic_incidence = Vector{BasicSymbolic}[]
319+
varsbuf = Set()
287320
eqs_to_retain = trues(length(eqs))
288-
for (i, eq′) in enumerate(eqs)
289-
if eq′.lhs isa Connection
290-
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
291-
return nothing
292-
end
293-
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
294-
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
321+
for (i, eq) in enumerate(eqs)
322+
if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential &&
323+
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv)
295324
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296325
# we want to store `nothing` in the map because that means `fast_substitute`
297326
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
298327
# the differentiated expression and error.
299-
param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing)
328+
param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing)
300329
eqs_to_retain[i] = false
301330
# change the equation if the RHS is `missing` so the rest of this loop works
302-
eq = eq′.lhs ~ coalesce(eq.rhs, 0.0)
331+
eq = 0.0 ~ coalesce(eq.rhs, 0.0)
303332
end
304-
if _iszero(eq′.lhs)
305-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
306-
eq = eq′
307-
else
308-
lhs = quick_cancel ? quick_cancel_expr(eq′.lhs) : eq′.lhs
309-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
333+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
334+
if !_iszero(eq.lhs)
335+
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
310336
eq = 0 ~ rhs - lhs
311337
end
312-
vars!(vars, eq.rhs, op = Symbolics.Operator)
313-
for v in vars
314-
_var, _ = var_from_nested_derivative(v)
315-
any(isequal(_var), ivs) && continue
316-
if isparameter(_var) ||
317-
(iscall(_var) && isparameter(operation(_var)))
318-
if is_time_dependent_parameter(_var, iv) &&
319-
!haskey(param_derivative_map, Differential(iv)(_var))
338+
empty!(varsbuf)
339+
vars!(varsbuf, eq; op = Symbolics.Operator)
340+
incidence = Set{BasicSymbolic}()
341+
isalgeq = true
342+
for v in varsbuf
343+
# additionally track brownians in fullvars
344+
if v in browns
345+
addvar!(v, BROWNIAN)
346+
push!(incidence, v)
347+
end
348+
349+
# TODO: Can we handle this without `isparameter`?
350+
if symbolic_contains(v, ps) ||
351+
getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v)
352+
if is_time_dependent_parameter(v, ps, iv) &&
353+
!haskey(param_derivative_map, Differential(iv)(v))
320354
# Parameter derivatives default to zero - they stay constant
321355
# between callbacks
322-
param_derivative_map[Differential(iv)(_var)] = 0.0
356+
param_derivative_map[Differential(iv)(v)] = 0.0
323357
end
324358
continue
325359
end
326-
v = scalarize(v)
327-
if v isa AbstractArray
328-
append!(varsvec, v)
329-
else
330-
push!(varsvec, v)
331-
end
332-
end
333-
isalgeq = true
334-
unknownvars = []
335-
for var in varsvec
336-
ModelingToolkit.isdelay(var, iv) && continue
337-
set_incidence = true
338-
@label ANOTHER_VAR
339-
_var, _ = var_from_nested_derivative(var)
340-
any(isequal(_var), ivs) && continue
341-
if isparameter(_var) ||
342-
(iscall(_var) && isparameter(operation(_var)))
343-
continue
344-
end
345-
varidx = addvar!(var)
346-
set_incidence && push!(unknownvars, var)
347-
348-
dvar = var
349-
idx = varidx
350-
while isdifferential(dvar)
351-
if !(idx in dervaridxs)
352-
push!(dervaridxs, idx)
360+
361+
isequal(v, iv) && continue
362+
isdelay(v, iv) && continue
363+
364+
if !symbolic_contains(v, dvs)
365+
isvalid = iscall(v) && operation(v) isa Union{Shift, Sample, Hold}
366+
v′ = v
367+
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
368+
v′ = arguments(v′)[1]
369+
if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope
370+
isvalid = true
371+
break
372+
end
373+
end
374+
if !isvalid
375+
throw(ArgumentError("$v is present in the system but $v′ is not an unknown."))
376+
end
377+
378+
addvar!(v, VARIABLE)
379+
if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) &&
380+
(it = input_timedomain(v)) !== nothing
381+
v′ = only(arguments(v))
382+
addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE)
353383
end
354-
isalgeq = false
355-
dvar = arguments(dvar)[1]
356-
idx = addvar!(dvar)
357384
end
358385

359-
dvar = var
360-
idx = varidx
386+
isalgeq &= !isdifferential(v)
361387

362-
if iscall(var) && operation(var) isa Symbolics.Operator &&
363-
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
364-
set_incidence = false
365-
var = only(arguments(var))
366-
var = setmetadata(var, VariableTimeDomain, it)
367-
@goto ANOTHER_VAR
388+
if symbolic_type(v) == ArraySymbolic()
389+
vv = collect(v)
390+
union!(incidence, vv)
391+
map(vv) do vi
392+
addvar!(vi, VARIABLE)
393+
end
394+
else
395+
push!(incidence, v)
396+
addvar!(v, VARIABLE)
368397
end
369398
end
370-
push!(symbolic_incidence, copy(unknownvars))
371-
empty!(unknownvars)
372-
empty!(vars)
373-
empty!(varsvec)
399+
374400
if isalgeq
375401
eqs[i] = eq
376402
else
377403
eqs[i] = eqs[i].lhs ~ rhs
378404
end
405+
push!(symbolic_incidence, collect(incidence))
406+
end
407+
408+
dervaridxs = OrderedSet{Int}()
409+
for (i, v) in enumerate(fullvars)
410+
while isdifferential(v)
411+
push!(dervaridxs, i)
412+
v = arguments(v)[1]
413+
i = addvar!(v, VARIABLE)
414+
end
379415
end
380416
eqs = eqs[eqs_to_retain]
381417
neqs = length(eqs)
@@ -389,6 +425,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
389425
symbolic_incidence = symbolic_incidence[sortidxs]
390426
end
391427

428+
# Handle shifts - find lowest shift and add intermediates with derivative edges
392429
### Handle discrete variables
393430
lowest_shift = Dict()
394431
for var in fullvars
@@ -422,12 +459,13 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
422459
for s in (steps - 1):-1:(lshift + 1)
423460
sf = Shift(tt, s)
424461
dvar = sf(v)
425-
idx = addvar!(dvar)
462+
idx = addvar!(dvar, VARIABLE)
426463
if !(idx in dervaridxs)
427464
push!(dervaridxs, idx)
428465
end
429466
end
430467
end
468+
431469
# sort `fullvars` such that the mass matrix is as diagonal as possible.
432470
dervaridxs = collect(dervaridxs)
433471
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
@@ -451,6 +489,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
451489
var2idx = Dict(fullvars .=> eachindex(fullvars))
452490
dervaridxs = 1:length(dervaridxs)
453491

492+
# build `var_to_diff`
454493
nvars = length(fullvars)
455494
diffvars = []
456495
var_to_diff = DiffGraph(nvars, true)
@@ -462,6 +501,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
462501
var_to_diff[diffvaridx] = dervaridx
463502
end
464503

504+
# build incidence graph
465505
graph = BipartiteGraph(neqs, nvars, Val(false))
466506
for (ie, vars) in enumerate(symbolic_incidence), v in vars
467507
jv = var2idx[v]
@@ -731,3 +771,21 @@ function _structural_simplify!(state::TearingState; simplify = false,
731771

732772
ModelingToolkit.invalidate_cache!(sys)
733773
end
774+
775+
struct DifferentiatedVariableNotUnknownError <: Exception
776+
differentiated::Any
777+
undifferentiated::Any
778+
end
779+
780+
function Base.showerror(io::IO, err::DifferentiatedVariableNotUnknownError)
781+
undiff = err.undifferentiated
782+
diff = err.differentiated
783+
print(io,
784+
"Variable $undiff occurs differentiated as $diff but is not an unknown of the system.")
785+
scope = getmetadata(undiff, SymScope, LocalScope())
786+
depth = expected_scope_depth(scope)
787+
if depth > 0
788+
print(io,
789+
"\nVariable $undiff expects $depth more levels in the hierarchy to be an unknown.")
790+
end
791+
end

0 commit comments

Comments
 (0)