Skip to content

Commit eb561a9

Browse files
fixup! fixup! fix: improve hack supporting unscalarized usage of array observed variables
1 parent b1fddf4 commit eb561a9

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,14 +587,18 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
587587
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588588
# are equations, add an equation `p ~ [p[1], p[2], ...]`
589589
# allow topsort to reorder them
590-
# only add the new equation if all `p[i]` are present
590+
# only add the new equation if all `p[i]` are present and the unscalarized
591+
# form is used in any equation (observed or not)
591592
# we first count the number of times the scalarized form of each observed
592593
# variable occurs in observed equations (and unknowns if it's split).
593594

594595
# map of array observed variable (unscalarized) to number of its
595596
# scalarized terms that appear in observed equations
596597
arr_obs_occurrences = Dict()
598+
# to check if array variables occur in unscalarized form anywhere
599+
all_vars = Set()
597600
for eq in obs
601+
vars!(all_vars, eq.rhs)
598602
lhs = eq.lhs
599603
iscall(lhs) || continue
600604
operation(lhs) === getindex || continue
@@ -616,13 +620,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
616620
cnt == 0 && continue
617621
arr_obs_occurrences[arg1] = cnt + 1
618622
end
623+
for eq in neweqs
624+
vars!(all_vars, eq.rhs)
625+
end
619626
obs_arr_eqs = Equation[]
620627
for (arrvar, cnt) in arr_obs_occurrences
621628
cnt == length(arrvar) || continue
629+
arrvar in all_vars || continue
622630
# firstindex returns 1 for multidimensional array symbolics
623631
firstind = first(eachindex(arrvar))
624632
scal = [arrvar[i] for i in eachindex(arrvar)]
625-
all(sym -> any(eq -> isequal(sym, eq.lhs)))
626633
# respect non-1-indexed arrays
627634
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
628635
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen

0 commit comments

Comments
 (0)