Skip to content

Commit b1fddf4

Browse files
fixup! fix: improve hack supporting unscalarized usage of array observed variables
1 parent 7e17bda commit b1fddf4

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -574,30 +574,62 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574
# TODO: compute the dependency correctly so that we don't have to do this
575575
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
576576

577+
unknowns = Any[v
578+
for (i, v) in enumerate(fullvars)
579+
if diff_to_var[i] === nothing && ispresent(i)]
580+
if !isempty(extra_vars)
581+
for v in extra_vars
582+
push!(unknowns, old_fullvars[v])
583+
end
584+
end
585+
@set! sys.unknowns = unknowns
586+
577587
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
578588
# are equations, add an equation `p ~ [p[1], p[2], ...]`
579589
# allow topsort to reorder them
590+
# only add the new equation if all `p[i]` are present
591+
# we first count the number of times the scalarized form of each observed
592+
# variable occurs in observed equations (and unknowns if it's split).
580593

581-
handled_obs_arr = Set()
582-
obs_arr_eqs = Equation[]
594+
# map of array observed variable (unscalarized) to number of its
595+
# scalarized terms that appear in observed equations
596+
arr_obs_occurrences = Dict()
583597
for eq in obs
584598
lhs = eq.lhs
585599
iscall(lhs) || continue
586600
operation(lhs) === getindex || continue
587-
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
601+
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
588602
arg1 = arguments(lhs)[1]
589-
arg1 in handled_obs_arr && continue
603+
cnt = get(arr_obs_occurrences, arg1, 0)
604+
arr_obs_occurrences[arg1] = cnt + 1
605+
continue
606+
end
607+
# count variables in unknowns if they are scalarized forms of variables
608+
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
609+
# is an observed equation.
610+
for sym in unknowns
611+
iscall(sym) || continue
612+
operation(sym) === getindex || continue
613+
Symbolics.shape(sym) != Symbolics.Unknown() || continue
614+
arg1 = arguments(sym)[1]
615+
cnt = get(arr_obs_occurrences, arg1, 0)
616+
cnt == 0 && continue
617+
arr_obs_occurrences[arg1] = cnt + 1
618+
end
619+
obs_arr_eqs = Equation[]
620+
for (arrvar, cnt) in arr_obs_occurrences
621+
cnt == length(arrvar) || continue
590622
# firstindex returns 1 for multidimensional array symbolics
591-
firstind = first(eachindex(arg1))
592-
scal = [arg1[i] for i in eachindex(arg1)]
623+
firstind = first(eachindex(arrvar))
624+
scal = [arrvar[i] for i in eachindex(arrvar)]
625+
all(sym -> any(eq -> isequal(sym, eq.lhs)))
593626
# respect non-1-indexed arrays
594627
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
595628
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
596629
# try to `create_array(OffsetArray{...}, ...)` which errors.
597630
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
598631
# of `scal`.
599-
push!(obs_arr_eqs, arg1 ~ change_origin(Origin(firstind), scal))
600-
push!(handled_obs_arr, arg1)
632+
push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal))
601633
end
602634
append!(obs, obs_arr_eqs)
603635
append!(subeqs, obs_arr_eqs)
@@ -607,15 +639,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
607639
@set! sys.eqs = neweqs
608640
@set! sys.observed = obs
609641

610-
unknowns = Any[v
611-
for (i, v) in enumerate(fullvars)
612-
if diff_to_var[i] === nothing && ispresent(i)]
613-
if !isempty(extra_vars)
614-
for v in extra_vars
615-
push!(unknowns, old_fullvars[v])
616-
end
617-
end
618-
@set! sys.unknowns = unknowns
619642
@set! sys.substitutions = Substitutions(subeqs, deps)
620643

621644
# Only makes sense for time-dependent

0 commit comments

Comments
 (0)