Skip to content

Commit c8ac522

Browse files
Merge pull request #3126 from AayushSabharwal/as/symarray-fixes
fix: some fixes related to usage of array symbolics
2 parents ba519bf + 3f7d45a commit c8ac522

File tree

5 files changed

+354
-46
lines changed

5 files changed

+354
-46
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 174 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ end
230230
=#
231231

232232
function tearing_reassemble(state::TearingState, var_eq_matching,
233-
full_var_eq_matching = nothing; simplify = false, mm = nothing)
233+
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
234234
@unpack fullvars, sys, structure = state
235235
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236236
extra_vars = Int[]
@@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574
# TODO: compute the dependency correctly so that we don't have to do this
575575
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
576576

577-
# HACK: Substitute non-scalarized symbolic arrays of observed variables
578-
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579-
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580-
# by the topological sorting and dependency identification pieces
581-
obs_arr_subs = Dict()
582-
583-
for eq in obs
584-
lhs = eq.lhs
585-
iscall(lhs) || continue
586-
operation(lhs) === getindex || continue
587-
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
588-
arg1 = arguments(lhs)[1]
589-
haskey(obs_arr_subs, arg1) && continue
590-
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]]
591-
index_first = eachindex(arg1)[1]
592-
593-
# respect non-1-indexed arrays
594-
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
595-
obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1])
596-
end
597-
for i in eachindex(neweqs)
598-
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
599-
end
600-
for i in eachindex(obs)
601-
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
602-
end
603-
for i in eachindex(subeqs)
604-
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
605-
end
606-
607-
@set! sys.eqs = neweqs
608-
@set! sys.observed = obs
609-
610577
unknowns = Any[v
611578
for (i, v) in enumerate(fullvars)
612579
if diff_to_var[i] === nothing && ispresent(i)]
@@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
616583
end
617584
end
618585
@set! sys.unknowns = unknowns
586+
587+
obs, subeqs, deps = cse_and_array_hacks(
588+
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
589+
590+
@set! sys.eqs = neweqs
591+
@set! sys.observed = obs
592+
619593
@set! sys.substitutions = Substitutions(subeqs, deps)
620594

621595
# Only makes sense for time-dependent
@@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629603
return invalidate_cache!(sys)
630604
end
631605

606+
"""
607+
# HACK 1
608+
609+
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610+
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611+
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
612+
avoid the unnecessary cost. This and the below hack are implemented simultaneously
613+
614+
# HACK 2
615+
616+
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617+
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618+
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619+
not) we first count the number of times the scalarized form of each observed variable
620+
occurs in observed equations (and unknowns if it's split).
621+
"""
622+
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
623+
# HACK 1
624+
# mapping of rhs to temporary CSE variable
625+
# `f(...) => tmpvar` in above example
626+
rhs_to_tempvar = Dict()
627+
628+
# HACK 2
629+
# map of array observed variable (unscalarized) to number of its
630+
# scalarized terms that appear in observed equations
631+
arr_obs_occurrences = Dict()
632+
# to check if array variables occur in unscalarized form anywhere
633+
all_vars = Set()
634+
for (i, eq) in enumerate(obs)
635+
lhs = eq.lhs
636+
rhs = eq.rhs
637+
vars!(all_vars, rhs)
638+
639+
# HACK 1
640+
if cse && is_getindexed_array(rhs)
641+
rhs_arr = arguments(rhs)[1]
642+
if !haskey(rhs_to_tempvar, rhs_arr)
643+
tempvar = gensym(Symbol(lhs))
644+
N = length(rhs_arr)
645+
tempvar = unwrap(Symbolics.variable(
646+
tempvar; T = Symbolics.symtype(rhs_arr)))
647+
tempvar = setmetadata(
648+
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
649+
tempeq = tempvar ~ rhs_arr
650+
rhs_to_tempvar[rhs_arr] = tempvar
651+
push!(obs, tempeq)
652+
push!(subeqs, tempeq)
653+
end
654+
655+
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
656+
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
657+
# which fails the topological sort
658+
neweq = lhs ~ getindex_wrapper(
659+
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
660+
obs[i] = neweq
661+
subeqi = findfirst(isequal(eq), subeqs)
662+
if subeqi !== nothing
663+
subeqs[subeqi] = neweq
664+
end
665+
end
666+
# end HACK 1
667+
668+
array || continue
669+
iscall(lhs) || continue
670+
operation(lhs) === getindex || continue
671+
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
672+
arg1 = arguments(lhs)[1]
673+
cnt = get(arr_obs_occurrences, arg1, 0)
674+
arr_obs_occurrences[arg1] = cnt + 1
675+
continue
676+
end
677+
678+
# Also do CSE for `equations(sys)`
679+
if cse
680+
for (i, eq) in enumerate(neweqs)
681+
(; lhs, rhs) = eq
682+
is_getindexed_array(rhs) || continue
683+
rhs_arr = arguments(rhs)[1]
684+
if !haskey(rhs_to_tempvar, rhs_arr)
685+
tempvar = gensym(Symbol(lhs))
686+
N = length(rhs_arr)
687+
tempvar = unwrap(Symbolics.variable(
688+
tempvar; T = Symbolics.symtype(rhs_arr)))
689+
tempvar = setmetadata(
690+
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
691+
tempeq = tempvar ~ rhs_arr
692+
rhs_to_tempvar[rhs_arr] = tempvar
693+
push!(obs, tempeq)
694+
push!(subeqs, tempeq)
695+
end
696+
# don't need getindex_wrapper, but do it anyway to know that this
697+
# hack took place
698+
neweq = lhs ~ getindex_wrapper(
699+
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
700+
neweqs[i] = neweq
701+
end
702+
end
703+
704+
# count variables in unknowns if they are scalarized forms of variables
705+
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
706+
# is an observed equation.
707+
for sym in unknowns
708+
iscall(sym) || continue
709+
operation(sym) === getindex || continue
710+
Symbolics.shape(sym) != Symbolics.Unknown() || continue
711+
arg1 = arguments(sym)[1]
712+
cnt = get(arr_obs_occurrences, arg1, 0)
713+
cnt == 0 && continue
714+
arr_obs_occurrences[arg1] = cnt + 1
715+
end
716+
for eq in neweqs
717+
vars!(all_vars, eq.rhs)
718+
end
719+
obs_arr_eqs = Equation[]
720+
for (arrvar, cnt) in arr_obs_occurrences
721+
cnt == length(arrvar) || continue
722+
arrvar in all_vars || continue
723+
# firstindex returns 1 for multidimensional array symbolics
724+
firstind = first(eachindex(arrvar))
725+
scal = [arrvar[i] for i in eachindex(arrvar)]
726+
# respect non-1-indexed arrays
727+
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
728+
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
729+
# try to `create_array(OffsetArray{...}, ...)` which errors.
730+
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
731+
# of `scal`.
732+
push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal))
733+
end
734+
append!(obs, obs_arr_eqs)
735+
append!(subeqs, obs_arr_eqs)
736+
737+
# need to re-sort subeqs
738+
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
739+
740+
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
741+
for i in 1:length(subeqs)]
742+
743+
return obs, subeqs, deps
744+
end
745+
746+
function is_getindexed_array(rhs)
747+
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
748+
iscall(rhs) && operation(rhs) === getindex &&
749+
Symbolics.shape(rhs) != Symbolics.Unknown()
750+
end
751+
752+
# PART OF HACK 1
753+
getindex_wrapper(x, i) = x[i...]
754+
755+
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
756+
757+
# PART OF HACK 2
758+
function change_origin(origin, arr)
759+
return origin(arr)
760+
end
761+
762+
@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
763+
size = size(arr)
764+
eltype = eltype(arr)
765+
ndims = ndims(arr)
766+
end
767+
632768
function tearing(state::TearingState; kwargs...)
633769
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
634770
complete!(state.structure)
@@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
643779
instead, which calls this function internally.
644780
"""
645781
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
646-
simplify = false, kwargs...)
782+
simplify = false, cse_hack = true, array_hack = true, kwargs...)
647783
var_eq_matching, full_var_eq_matching = tearing(state)
648784
invalidate_cache!(tearing_reassemble(
649-
state, var_eq_matching, full_var_eq_matching; mm, simplify))
785+
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
650786
end
651787

652788
"""
@@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
668804
the system is balanced.
669805
"""
670806
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
671-
mm = nothing, kwargs...)
807+
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
672808
jac = let state = state
673809
(eqs, vars) -> begin
674810
symeqs = EquationsView(state)[eqs]
@@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
692828
end
693829
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
694830
kwargs...)
695-
tearing_reassemble(state, var_eq_matching; simplify, mm)
831+
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
696832
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ function generate_initializesystem(sys::ODESystem;
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), kwargs...)
15-
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
15+
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
16+
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
1617
vars_set = Set(vars) # for efficient in-lookup
1718

18-
eqs = equations(sys)
1919
idxs_diff = isdiffeq.(eqs)
2020
idxs_alge = .!idxs_diff
2121

@@ -24,7 +24,7 @@ function generate_initializesystem(sys::ODESystem;
2424
D = Differential(get_iv(sys))
2525
diffmap = merge(
2626
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
27-
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
27+
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
2828
)
2929

3030
# 1) process dummy derivatives and u0map into initialization system
@@ -166,15 +166,14 @@ function generate_initializesystem(sys::ODESystem;
166166
)
167167

168168
# 7) use observed equations for guesses of observed variables if not provided
169-
obseqs = observed(sys)
170-
for eq in obseqs
169+
for eq in trueobs
171170
haskey(defs, eq.lhs) && continue
172171
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
173172

174173
defs[eq.lhs] = eq.rhs
175174
end
176175

177-
eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,))
176+
eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
178177
vars = [vars; collect(values(paramsubs))]
179178
for k in keys(defs)
180179
defs[k] = substitute(defs[k], paramsubs)
@@ -324,3 +323,50 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
324323
return nothing, nothing, nothing, nothing
325324
end
326325
end
326+
327+
"""
328+
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
329+
initialization.
330+
"""
331+
function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
332+
subs = Dict()
333+
tempvars = Set()
334+
rm_idxs = Int[]
335+
for (i, eq) in enumerate(obseqs)
336+
iscall(eq.rhs) || continue
337+
if operation(eq.rhs) == StructuralTransformations.change_origin
338+
push!(rm_idxs, i)
339+
continue
340+
end
341+
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
342+
var, idxs = arguments(eq.rhs)
343+
subs[eq.rhs] = var[idxs...]
344+
push!(tempvars, var)
345+
end
346+
end
347+
348+
for (i, eq) in enumerate(eqs)
349+
iscall(eq.rhs) || continue
350+
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
351+
var, idxs = arguments(eq.rhs)
352+
subs[eq.rhs] = var[idxs...]
353+
push!(tempvars, var)
354+
end
355+
end
356+
357+
for (i, eq) in enumerate(obseqs)
358+
if eq.lhs in tempvars
359+
subs[eq.lhs] = eq.rhs
360+
push!(rm_idxs, i)
361+
end
362+
end
363+
364+
obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)]
365+
obseqs = map(obseqs) do eq
366+
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
367+
end
368+
eqs = map(eqs) do eq
369+
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
370+
end
371+
return obseqs, eqs
372+
end

src/utils.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,15 +389,21 @@ function vars!(vars, O; op = Differential)
389389
f = getcalledparameter(O)
390390
push!(vars, f)
391391
for arg in arguments(O)
392-
vars!(vars, arg; op)
392+
if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
393+
for el in arg
394+
vars!(vars, unwrap(el); op)
395+
end
396+
else
397+
vars!(vars, arg; op)
398+
end
393399
end
394400
return vars
395401
end
396402
return push!(vars, O)
397403
end
398404
if symbolic_type(O) == NotSymbolic() && O isa AbstractArray
399405
for arg in O
400-
vars!(vars, arg; op)
406+
vars!(vars, unwrap(arg); op)
401407
end
402408
return vars
403409
end

test/odesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,14 @@ end
14481448
@test_nowarn ODESystem(Equation[], t; parameter_dependencies = [p ~ 1.0], name = :a)
14491449
end
14501450

1451+
@testset "Variable discovery in arrays of `Num` inside callable symbolic" begin
1452+
@variables x(t) y(t)
1453+
@parameters foo(::AbstractVector)
1454+
sys = @test_nowarn ODESystem(D(x) ~ foo([x, 2y]), t; name = :sys)
1455+
@test length(unknowns(sys)) == 2
1456+
@test any(isequal(y), unknowns(sys))
1457+
end
1458+
14511459
@testset "Inplace observed" begin
14521460
@variables x(t)
14531461
@parameters p[1:2] q

0 commit comments

Comments
 (0)