Skip to content

Commit 5c2edf4

Browse files
fix: improve hack supporting unscalarized usage of array observed variables
1 parent a03f061 commit 5c2edf4

File tree

2 files changed

+83
-29
lines changed

2 files changed

+83
-29
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -574,48 +574,78 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574
# TODO: compute the dependency correctly so that we don't have to do this
575575
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
576576

577-
# HACK: Substitute non-scalarized symbolic arrays of observed variables
578-
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579-
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580-
# by the topological sorting and dependency identification pieces
581-
obs_arr_subs = Dict()
577+
unknowns = Any[v
578+
for (i, v) in enumerate(fullvars)
579+
if diff_to_var[i] === nothing && ispresent(i)]
580+
if !isempty(extra_vars)
581+
for v in extra_vars
582+
push!(unknowns, old_fullvars[v])
583+
end
584+
end
585+
@set! sys.unknowns = unknowns
582586

587+
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588+
# are equations, add an equation `p ~ [p[1], p[2], ...]`
589+
# allow topsort to reorder them
590+
# only add the new equation if all `p[i]` are present and the unscalarized
591+
# form is used in any equation (observed or not)
592+
# we first count the number of times the scalarized form of each observed
593+
# variable occurs in observed equations (and unknowns if it's split).
594+
595+
# map of array observed variable (unscalarized) to number of its
596+
# scalarized terms that appear in observed equations
597+
arr_obs_occurrences = Dict()
598+
# to check if array variables occur in unscalarized form anywhere
599+
all_vars = Set()
583600
for eq in obs
601+
vars!(all_vars, eq.rhs)
584602
lhs = eq.lhs
585603
iscall(lhs) || continue
586604
operation(lhs) === getindex || continue
587-
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
605+
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
588606
arg1 = arguments(lhs)[1]
589-
haskey(obs_arr_subs, arg1) && continue
590-
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]]
591-
index_first = eachindex(arg1)[1]
592-
607+
cnt = get(arr_obs_occurrences, arg1, 0)
608+
arr_obs_occurrences[arg1] = cnt + 1
609+
continue
610+
end
611+
# count variables in unknowns if they are scalarized forms of variables
612+
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
613+
# is an observed equation.
614+
for sym in unknowns
615+
iscall(sym) || continue
616+
operation(sym) === getindex || continue
617+
Symbolics.shape(sym) != Symbolics.Unknown() || continue
618+
arg1 = arguments(sym)[1]
619+
cnt = get(arr_obs_occurrences, arg1, 0)
620+
cnt == 0 && continue
621+
arr_obs_occurrences[arg1] = cnt + 1
622+
end
623+
for eq in neweqs
624+
vars!(all_vars, eq.rhs)
625+
end
626+
obs_arr_eqs = Equation[]
627+
for (arrvar, cnt) in arr_obs_occurrences
628+
cnt == length(arrvar) || continue
629+
arrvar in all_vars || continue
630+
# firstindex returns 1 for multidimensional array symbolics
631+
firstind = first(eachindex(arrvar))
632+
scal = [arrvar[i] for i in eachindex(arrvar)]
593633
# respect non-1-indexed arrays
594634
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
595-
obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1])
596-
end
597-
for i in eachindex(neweqs)
598-
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
599-
end
600-
for i in eachindex(obs)
601-
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
602-
end
603-
for i in eachindex(subeqs)
604-
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
635+
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
636+
# try to `create_array(OffsetArray{...}, ...)` which errors.
637+
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
638+
# of `scal`.
639+
push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal))
605640
end
641+
append!(obs, obs_arr_eqs)
642+
append!(subeqs, obs_arr_eqs)
643+
# need to re-sort subeqs
644+
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
606645

607646
@set! sys.eqs = neweqs
608647
@set! sys.observed = obs
609648

610-
unknowns = Any[v
611-
for (i, v) in enumerate(fullvars)
612-
if diff_to_var[i] === nothing && ispresent(i)]
613-
if !isempty(extra_vars)
614-
for v in extra_vars
615-
push!(unknowns, old_fullvars[v])
616-
end
617-
end
618-
@set! sys.unknowns = unknowns
619649
@set! sys.substitutions = Substitutions(subeqs, deps)
620650

621651
# Only makes sense for time-dependent
@@ -629,6 +659,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629659
return invalidate_cache!(sys)
630660
end
631661

662+
function change_origin(origin, arr)
663+
return origin(arr)
664+
end
665+
666+
@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
667+
size = size(arr)
668+
eltype = eltype(arr)
669+
ndims = ndims(arr)
670+
end
671+
632672
function tearing(state::TearingState; kwargs...)
633673
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
634674
complete!(state.structure)

test/structural_transformation/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,17 @@ end
4040
@test ModelingToolkit.𝑠neighbors(g, 1) == [2]
4141
@test ModelingToolkit.𝑑neighbors(g, 2) == [1]
4242
end
43+
44+
@testset "array observed used unscalarized in another observed" begin
45+
@variables x(t) y(t)[1:2] z(t)[1:2]
46+
@parameters foo(::AbstractVector)[1:2]
47+
_tmp_fn(x) = 2x
48+
@mtkbuild sys = ODESystem(
49+
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
50+
@test length(equations(sys)) == 1
51+
@test length(observed(sys)) == 6
52+
@test any(eq -> isequal(eq.lhs, y), observed(sys))
53+
@test any(eq -> isequal(eq.lhs, z), observed(sys))
54+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
55+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
56+
end

0 commit comments

Comments
 (0)