Skip to content

Commit 74d69d8

Browse files
fix: improve hack supporting unscalarized usage of array observed variables
1 parent a87eb46 commit 74d69d8

File tree

2 files changed

+43
-20
lines changed

2 files changed

+43
-20
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -574,35 +574,35 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574
# TODO: compute the dependency correctly so that we don't have to do this
575575
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
576576

577-
# HACK: Substitute non-scalarized symbolic arrays of observed variables
578-
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579-
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580-
# by the topological sorting and dependency identification pieces
581-
obs_arr_subs = Dict()
577+
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
578+
# are equations, add an equation `p ~ [p[1], p[2], ...]`
579+
# allow topsort to reorder them
582580

581+
handled_obs_arr = Set()
582+
obs_arr_eqs = Equation[]
583583
for eq in obs
584584
lhs = eq.lhs
585585
iscall(lhs) || continue
586586
operation(lhs) === getindex || continue
587587
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
588588
arg1 = arguments(lhs)[1]
589-
haskey(obs_arr_subs, arg1) && continue
590-
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]]
591-
index_first = eachindex(arg1)[1]
592-
589+
arg1 in handled_obs_arr && continue
590+
# firstindex returns 1 for multidimensional array symbolics
591+
firstind = first(eachindex(arg1))
592+
scal = [arg1[i] for i in eachindex(arg1)]
593593
# respect non-1-indexed arrays
594594
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
595-
obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1])
596-
end
597-
for i in eachindex(neweqs)
598-
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
599-
end
600-
for i in eachindex(obs)
601-
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
602-
end
603-
for i in eachindex(subeqs)
604-
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
605-
end
595+
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
596+
# try to `create_array(OffsetArray{...}, ...)` which errors.
597+
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
598+
# of `scal`.
599+
push!(obs_arr_eqs, arg1 ~ change_origin(Origin(firstind), scal))
600+
push!(handled_obs_arr, arg1)
601+
end
602+
append!(obs, obs_arr_eqs)
603+
append!(subeqs, obs_arr_eqs)
604+
# need to re-sort subeqs
605+
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
606606

607607
@set! sys.eqs = neweqs
608608
@set! sys.observed = obs
@@ -629,6 +629,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629629
return invalidate_cache!(sys)
630630
end
631631

632+
function change_origin(origin, arr)
633+
return origin(arr)
634+
end
635+
636+
@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
637+
size = size(arr)
638+
eltype = eltype(arr)
639+
ndims = ndims(arr)
640+
end
641+
632642
function tearing(state::TearingState; kwargs...)
633643
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
634644
complete!(state.structure)

test/structural_transformation/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,16 @@ end
4040
@test ModelingToolkit.𝑠neighbors(g, 1) == [2]
4141
@test ModelingToolkit.𝑑neighbors(g, 2) == [1]
4242
end
43+
44+
@testset "array observed used unscalarized in another observed" begin
45+
@variables x(t) y(t)[1:2] z(t)[1:2]
46+
@parameters foo(::AbstractVector)[1:2]
47+
_tmp_fn(x) = 2x
48+
@mtkbuild sys = ODESystem([D(x) ~ z[1] + z[2], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
49+
@test length(equations(sys)) == 1
50+
@test length(observed(sys)) == 6
51+
@test any(eq -> isequal(eq.lhs, y), observed(sys))
52+
@test any(eq -> isequal(eq.lhs, z), observed(sys))
53+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
54+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
55+
end

0 commit comments

Comments
 (0)