@@ -204,7 +204,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204 """ The system of equations."""
205205 sys:: T
206206 """ The set of variables of the system."""
207- fullvars:: Vector
207+ fullvars:: Vector{BasicSymbolic}
208208 structure:: SystemStructure
209209 extra_eqs:: Vector
210210 param_derivative_map:: Dict{BasicSymbolic, Any}
@@ -254,128 +254,164 @@ function Base.push!(ev::EquationsView, eq)
254254 push! (ev. ts. extra_eqs, eq)
255255end
256256
257- function is_time_dependent_parameter (p, iv)
258- return iv != = nothing && isparameter (p) && iscall (p) &&
259- (operation (p) === getindex && is_time_dependent_parameter (arguments (p)[1 ], iv) ||
257+ function is_time_dependent_parameter (p, allps, iv)
258+ return iv != = nothing && p in allps && iscall (p) &&
259+ (operation (p) === getindex &&
260+ is_time_dependent_parameter (arguments (p)[1 ], allps, iv) ||
260261 (args = arguments (p); length (args)) == 1 && isequal (only (args), iv))
261262end
262263
264+ function symbolic_contains (var, set)
265+ var in set ||
266+ symbolic_type (var) == ArraySymbolic () &&
267+ Symbolics. shape (var) != Symbolics. Unknown () &&
268+ all (x -> x in set, Symbolics. scalarize (var))
269+ end
270+
263271function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
272+ # flatten system
264273 sys = flatten (sys)
265274 ivs = independent_variables (sys)
266275 iv = length (ivs) == 1 ? ivs[1 ] : nothing
267- # scalarize array equations, without scalarizing arguments to registered functions
268- eqs = flatten_equations (copy ( equations (sys) ))
276+ # flatten array equations
277+ eqs = flatten_equations (equations (sys))
269278 neqs = length (eqs)
270- dervaridxs = OrderedSet {Int} ()
271- var2idx = Dict {Any, Int} ()
272- symbolic_incidence = []
273- fullvars = []
274279 param_derivative_map = Dict {BasicSymbolic, Any} ()
275- var_counter = Ref (0 )
280+ # * Scalarize unknowns
281+ dvs = Set {BasicSymbolic} ()
282+ fullvars = BasicSymbolic[]
283+ for x in unknowns (sys)
284+ push! (dvs, x)
285+ xx = Symbolics. scalarize (x)
286+ if xx isa AbstractArray
287+ union! (dvs, xx)
288+ end
289+ end
290+ ps = Set {Symbolic} ()
291+ for x in full_parameters (sys)
292+ push! (ps, x)
293+ if symbolic_type (x) == ArraySymbolic () && Symbolics. shape (x) != Symbolics. Unknown ()
294+ xx = Symbolics. scalarize (x)
295+ union! (ps, xx)
296+ end
297+ end
298+ browns = Set {BasicSymbolic} ()
299+ for x in brownians (sys)
300+ push! (browns, x)
301+ xx = Symbolics. scalarize (x)
302+ if xx isa AbstractArray
303+ union! (browns, xx)
304+ end
305+ end
306+ var2idx = Dict {BasicSymbolic, Int} ()
276307 var_types = VariableType[]
277- addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
278- var -> get! (var2idx, var) do
308+ addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx, var_types = var_types
309+ (var, vtype) -> get! (var2idx, var) do
310+ push! (dvs, var)
279311 push! (fullvars, var)
280- push! (var_types, getvariabletype (var) )
281- var_counter[] += 1
312+ push! (var_types, vtype )
313+ return length (fullvars)
282314 end
283315 end
284316
285- vars = OrderedSet ()
286- varsvec = []
317+ # build symbolic incidence
318+ symbolic_incidence = Vector{BasicSymbolic}[]
319+ varsbuf = Set ()
287320 eqs_to_retain = trues (length (eqs))
288- for (i, eq′) in enumerate (eqs)
289- if eq′. lhs isa Connection
290- check ? error (" $(nameof (sys)) has unexpanded `connect` statements" ) :
291- return nothing
292- end
293- if iscall (eq′. lhs) && (op = operation (eq′. lhs)) isa Differential &&
294- isequal (op. x, iv) && is_time_dependent_parameter (only (arguments (eq′. lhs)), iv)
321+ for (i, eq) in enumerate (eqs)
322+ if iscall (eq. lhs) && (op = operation (eq. lhs)) isa Differential &&
323+ isequal (op. x, iv) && is_time_dependent_parameter (only (arguments (eq. lhs)), ps, iv)
295324 # parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296325 # we want to store `nothing` in the map because that means `fast_substitute`
297326 # will ignore the rule. We will this identify the presence of `eq′.lhs` in
298327 # the differentiated expression and error.
299- param_derivative_map[eq′ . lhs] = coalesce (eq′ . rhs, nothing )
328+ param_derivative_map[eq. lhs] = coalesce (eq. rhs, nothing )
300329 eqs_to_retain[i] = false
301330 # change the equation if the RHS is `missing` so the rest of this loop works
302- eq′ = eq′ . lhs ~ coalesce (eq′ . rhs, 0.0 )
331+ eq = 0.0 ~ coalesce (eq. rhs, 0.0 )
303332 end
304- if _iszero (eq′. lhs)
305- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
306- eq = eq′
307- else
308- lhs = quick_cancel ? quick_cancel_expr (eq′. lhs) : eq′. lhs
309- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
333+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
334+ if ! _iszero (eq. lhs)
335+ lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
310336 eq = 0 ~ rhs - lhs
311337 end
312- vars! (vars, eq. rhs, op = Symbolics. Operator)
313- for v in vars
314- _var, _ = var_from_nested_derivative (v)
315- any (isequal (_var), ivs) && continue
316- if isparameter (_var) ||
317- (iscall (_var) && isparameter (operation (_var)))
318- if is_time_dependent_parameter (_var, iv) &&
319- ! haskey (param_derivative_map, Differential (iv)(_var))
338+ empty! (varsbuf)
339+ vars! (varsbuf, eq; op = Symbolics. Operator)
340+ incidence = Set {BasicSymbolic} ()
341+ isalgeq = true
342+ for v in varsbuf
343+ # additionally track brownians in fullvars
344+ if v in browns
345+ addvar! (v, BROWNIAN)
346+ push! (incidence, v)
347+ end
348+
349+ # TODO : Can we handle this without `isparameter`?
350+ if symbolic_contains (v, ps) ||
351+ getmetadata (v, SymScope, LocalScope ()) isa GlobalScope && isparameter (v)
352+ if is_time_dependent_parameter (v, ps, iv) &&
353+ ! haskey (param_derivative_map, Differential (iv)(v))
320354 # Parameter derivatives default to zero - they stay constant
321355 # between callbacks
322- param_derivative_map[Differential (iv)(_var )] = 0.0
356+ param_derivative_map[Differential (iv)(v )] = 0.0
323357 end
324358 continue
325359 end
326- v = scalarize (v)
327- if v isa AbstractArray
328- append! (varsvec, v)
329- else
330- push! (varsvec, v)
331- end
332- end
333- isalgeq = true
334- unknownvars = []
335- for var in varsvec
336- ModelingToolkit. isdelay (var, iv) && continue
337- set_incidence = true
338- @label ANOTHER_VAR
339- _var, _ = var_from_nested_derivative (var)
340- any (isequal (_var), ivs) && continue
341- if isparameter (_var) ||
342- (iscall (_var) && isparameter (operation (_var)))
343- continue
344- end
345- varidx = addvar! (var)
346- set_incidence && push! (unknownvars, var)
347-
348- dvar = var
349- idx = varidx
350- while isdifferential (dvar)
351- if ! (idx in dervaridxs)
352- push! (dervaridxs, idx)
360+
361+ isequal (v, iv) && continue
362+ isdelay (v, iv) && continue
363+
364+ if ! symbolic_contains (v, dvs)
365+ isvalid = iscall (v) && operation (v) isa Union{Shift, Sample, Hold}
366+ v′ = v
367+ while ! isvalid && iscall (v′) && operation (v′) isa Union{Differential, Shift}
368+ v′ = arguments (v′)[1 ]
369+ if v′ in dvs || getmetadata (v′, SymScope, LocalScope ()) isa GlobalScope
370+ isvalid = true
371+ break
372+ end
373+ end
374+ if ! isvalid
375+ throw (ArgumentError (" $v is present in the system but $v′ is not an unknown." ))
376+ end
377+
378+ addvar! (v, VARIABLE)
379+ if iscall (v) && operation (v) isa Symbolics. Operator && ! isdifferential (v) &&
380+ (it = input_timedomain (v)) != = nothing
381+ v′ = only (arguments (v))
382+ addvar! (setmetadata (v′, VariableTimeDomain, it), VARIABLE)
353383 end
354- isalgeq = false
355- dvar = arguments (dvar)[1 ]
356- idx = addvar! (dvar)
357384 end
358385
359- dvar = var
360- idx = varidx
386+ isalgeq &= ! isdifferential (v)
361387
362- if iscall (var) && operation (var) isa Symbolics. Operator &&
363- ! isdifferential (var) && (it = input_timedomain (var)) != = nothing
364- set_incidence = false
365- var = only (arguments (var))
366- var = setmetadata (var, VariableTimeDomain, it)
367- @goto ANOTHER_VAR
388+ if symbolic_type (v) == ArraySymbolic ()
389+ vv = collect (v)
390+ union! (incidence, vv)
391+ map (vv) do vi
392+ addvar! (vi, VARIABLE)
393+ end
394+ else
395+ push! (incidence, v)
396+ addvar! (v, VARIABLE)
368397 end
369398 end
370- push! (symbolic_incidence, copy (unknownvars))
371- empty! (unknownvars)
372- empty! (vars)
373- empty! (varsvec)
399+
374400 if isalgeq
375401 eqs[i] = eq
376402 else
377403 eqs[i] = eqs[i]. lhs ~ rhs
378404 end
405+ push! (symbolic_incidence, collect (incidence))
406+ end
407+
408+ dervaridxs = OrderedSet {Int} ()
409+ for (i, v) in enumerate (fullvars)
410+ while isdifferential (v)
411+ push! (dervaridxs, i)
412+ v = arguments (v)[1 ]
413+ i = addvar! (v, VARIABLE)
414+ end
379415 end
380416 eqs = eqs[eqs_to_retain]
381417 neqs = length (eqs)
@@ -389,6 +425,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
389425 symbolic_incidence = symbolic_incidence[sortidxs]
390426 end
391427
428+ # Handle shifts - find lowest shift and add intermediates with derivative edges
392429 # ## Handle discrete variables
393430 lowest_shift = Dict ()
394431 for var in fullvars
@@ -422,12 +459,13 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
422459 for s in (steps - 1 ): - 1 : (lshift + 1 )
423460 sf = Shift (tt, s)
424461 dvar = sf (v)
425- idx = addvar! (dvar)
462+ idx = addvar! (dvar, VARIABLE )
426463 if ! (idx in dervaridxs)
427464 push! (dervaridxs, idx)
428465 end
429466 end
430467 end
468+
431469 # sort `fullvars` such that the mass matrix is as diagonal as possible.
432470 dervaridxs = collect (dervaridxs)
433471 sorted_fullvars = OrderedSet (fullvars[dervaridxs])
@@ -451,6 +489,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
451489 var2idx = Dict (fullvars .=> eachindex (fullvars))
452490 dervaridxs = 1 : length (dervaridxs)
453491
492+ # build `var_to_diff`
454493 nvars = length (fullvars)
455494 diffvars = []
456495 var_to_diff = DiffGraph (nvars, true )
@@ -462,6 +501,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
462501 var_to_diff[diffvaridx] = dervaridx
463502 end
464503
504+ # build incidence graph
465505 graph = BipartiteGraph (neqs, nvars, Val (false ))
466506 for (ie, vars) in enumerate (symbolic_incidence), v in vars
467507 jv = var2idx[v]
@@ -731,3 +771,21 @@ function _structural_simplify!(state::TearingState; simplify = false,
731771
732772 ModelingToolkit. invalidate_cache! (sys)
733773end
774+
775+ struct DifferentiatedVariableNotUnknownError <: Exception
776+ differentiated:: Any
777+ undifferentiated:: Any
778+ end
779+
780+ function Base. showerror (io:: IO , err:: DifferentiatedVariableNotUnknownError )
781+ undiff = err. undifferentiated
782+ diff = err. differentiated
783+ print (io,
784+ " Variable $undiff occurs differentiated as $diff but is not an unknown of the system." )
785+ scope = getmetadata (undiff, SymScope, LocalScope ())
786+ depth = expected_scope_depth (scope)
787+ if depth > 0
788+ print (io,
789+ " \n Variable $undiff expects $depth more levels in the hierarchy to be an unknown." )
790+ end
791+ end
0 commit comments