From a03f061f049b8cab15843b91acc6f2530f1a9c7d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Oct 2024 18:52:33 +0530 Subject: [PATCH 1/7] fix: fix variable discovery in arrays of `Num` passed to callable params --- src/utils.jl | 10 ++++++++-- test/odesystem.jl | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d2e8a3ea38..830ec98e44 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -389,7 +389,13 @@ function vars!(vars, O; op = Differential) f = getcalledparameter(O) push!(vars, f) for arg in arguments(O) - vars!(vars, arg; op) + if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray + for el in arg + vars!(vars, unwrap(el); op) + end + else + vars!(vars, arg; op) + end end return vars end @@ -397,7 +403,7 @@ function vars!(vars, O; op = Differential) end if symbolic_type(O) == NotSymbolic() && O isa AbstractArray for arg in O - vars!(vars, arg; op) + vars!(vars, unwrap(arg); op) end return vars end diff --git a/test/odesystem.jl b/test/odesystem.jl index 9446d105e0..cbc24300cd 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1447,3 +1447,11 @@ end @parameters p @test_nowarn ODESystem(Equation[], t; parameter_dependencies = [p ~ 1.0], name = :a) end + +@testset "Variable discovery in arrays of `Num` inside callable symbolic" begin + @variables x(t) y(t) + @parameters foo(::AbstractVector) + sys = @test_nowarn ODESystem(D(x) ~ foo([x, 2y]), t; name = :sys) + @test length(unknowns(sys)) == 2 + @test any(isequal(y), unknowns(sys)) +end From 5c2edf4ff43796be3284a42843fc73d75967bc58 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Oct 2024 18:53:17 +0530 Subject: [PATCH 2/7] fix: improve hack supporting unscalarized usage of array observed variables --- .../symbolics_tearing.jl | 98 +++++++++++++------ test/structural_transformation/utils.jl | 14 +++ 2 files changed, 83 insertions(+), 29 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 84cee928cd..889b75c611 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -574,48 +574,78 @@ function tearing_reassemble(state::TearingState, var_eq_matching, # TODO: compute the dependency correctly so that we don't have to do this obs = [fast_substitute(observed(sys), obs_sub); subeqs] - # HACK: Substitute non-scalarized symbolic arrays of observed variables - # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations - # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled - # by the topological sorting and dependency identification pieces - obs_arr_subs = Dict() + unknowns = Any[v + for (i, v) in enumerate(fullvars) + if diff_to_var[i] === nothing && ispresent(i)] + if !isempty(extra_vars) + for v in extra_vars + push!(unknowns, old_fullvars[v]) + end + end + @set! sys.unknowns = unknowns + # HACK: Add equations for array observed variables. If `p[i] ~ (...)` + # are equations, add an equation `p ~ [p[1], p[2], ...]` + # allow topsort to reorder them + # only add the new equation if all `p[i]` are present and the unscalarized + # form is used in any equation (observed or not) + # we first count the number of times the scalarized form of each observed + # variable occurs in observed equations (and unknowns if it's split). + + # map of array observed variable (unscalarized) to number of its + # scalarized terms that appear in observed equations + arr_obs_occurrences = Dict() + # to check if array variables occur in unscalarized form anywhere + all_vars = Set() for eq in obs + vars!(all_vars, eq.rhs) lhs = eq.lhs iscall(lhs) || continue operation(lhs) === getindex || continue - Symbolics.shape(lhs) !== Symbolics.Unknown() || continue + Symbolics.shape(lhs) != Symbolics.Unknown() || continue arg1 = arguments(lhs)[1] - haskey(obs_arr_subs, arg1) && continue - obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]] - index_first = eachindex(arg1)[1] - + cnt = get(arr_obs_occurrences, arg1, 0) + arr_obs_occurrences[arg1] = cnt + 1 + continue + end + # count variables in unknowns if they are scalarized forms of variables + # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` + # is an observed equation. + for sym in unknowns + iscall(sym) || continue + operation(sym) === getindex || continue + Symbolics.shape(sym) != Symbolics.Unknown() || continue + arg1 = arguments(sym)[1] + cnt = get(arr_obs_occurrences, arg1, 0) + cnt == 0 && continue + arr_obs_occurrences[arg1] = cnt + 1 + end + for eq in neweqs + vars!(all_vars, eq.rhs) + end + obs_arr_eqs = Equation[] + for (arrvar, cnt) in arr_obs_occurrences + cnt == length(arrvar) || continue + arrvar in all_vars || continue + # firstindex returns 1 for multidimensional array symbolics + firstind = first(eachindex(arrvar)) + scal = [arrvar[i] for i in eachindex(arrvar)] # respect non-1-indexed arrays # TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency - obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1]) - end - for i in eachindex(neweqs) - neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator) - end - for i in eachindex(obs) - obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator) - end - for i in eachindex(subeqs) - subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator) + # `change_origin` is required because `Origin(firstind)(scal)` makes codegen + # try to `create_array(OffsetArray{...}, ...)` which errors. + # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size` + # of `scal`. + push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal)) end + append!(obs, obs_arr_eqs) + append!(subeqs, obs_arr_eqs) + # need to re-sort subeqs + subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) @set! sys.eqs = neweqs @set! sys.observed = obs - unknowns = Any[v - for (i, v) in enumerate(fullvars) - if diff_to_var[i] === nothing && ispresent(i)] - if !isempty(extra_vars) - for v in extra_vars - push!(unknowns, old_fullvars[v]) - end - end - @set! sys.unknowns = unknowns @set! sys.substitutions = Substitutions(subeqs, deps) # Only makes sense for time-dependent @@ -629,6 +659,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching, return invalidate_cache!(sys) end +function change_origin(origin, arr) + return origin(arr) +end + +@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin + size = size(arr) + eltype = eltype(arr) + ndims = ndims(arr) +end + function tearing(state::TearingState; kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 8644d96945..c7146bab65 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -40,3 +40,17 @@ end @test ModelingToolkit.đť‘ neighbors(g, 1) == [2] @test ModelingToolkit.đť‘‘neighbors(g, 2) == [1] end + +@testset "array observed used unscalarized in another observed" begin + @variables x(t) y(t)[1:2] z(t)[1:2] + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @mtkbuild sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + @test length(equations(sys)) == 1 + @test length(observed(sys)) == 6 + @test any(eq -> isequal(eq.lhs, y), observed(sys)) + @test any(eq -> isequal(eq.lhs, z), observed(sys)) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) +end From 6c3576f0f18918fb5643fdedc4d9842eede18bc2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Oct 2024 20:33:59 +0530 Subject: [PATCH 3/7] feat: add simple CSE for array scalarization case --- .../symbolics_tearing.jl | 54 ++++++++++++++++++- test/structural_transformation/utils.jl | 18 ++++++- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 889b75c611..d7bec888b5 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end @set! sys.unknowns = unknowns + # HACK: Since we don't support array equations, any equation of the sort + # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly + # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically + # for this case to avoid the unnecessary cost. + # This and the below hack are implemented simultaneously + + # mapping of rhs to temporary CSE variable + # `f(...) => tmpvar` in above example + rhs_to_tempvar = Dict() + # HACK: Add equations for array observed variables. If `p[i] ~ (...)` # are equations, add an equation `p ~ [p[1], p[2], ...]` # allow topsort to reorder them @@ -597,9 +607,42 @@ function tearing_reassemble(state::TearingState, var_eq_matching, arr_obs_occurrences = Dict() # to check if array variables occur in unscalarized form anywhere all_vars = Set() - for eq in obs - vars!(all_vars, eq.rhs) + for (i, eq) in enumerate(obs) lhs = eq.lhs + rhs = eq.rhs + vars!(all_vars, rhs) + + # HACK 1 + if (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && + iscall(rhs) && operation(rhs) === getindex && + Symbolics.shape(rhs) != Symbolics.Unknown() + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + push!(obs, tempeq) + push!(subeqs, tempeq) + end + + # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different, + # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr` + # which fails the topological sort + neweq = lhs ~ getindex_wrapper( + rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + obs[i] = neweq + subeqi = findfirst(isequal(eq), subeqs) + if subeqi !== nothing + subeqs[subeqi] = neweq + end + end + # end HACK 1 + iscall(lhs) || continue operation(lhs) === getindex || continue Symbolics.shape(lhs) != Symbolics.Unknown() || continue @@ -640,6 +683,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end append!(obs, obs_arr_eqs) append!(subeqs, obs_arr_eqs) + # need to re-sort subeqs subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) @@ -659,6 +703,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching, return invalidate_cache!(sys) end +# PART OF HACK 1 +getindex_wrapper(x, i) = x[i...] + +@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}}) + +# PART OF HACK 2 function change_origin(origin, arr) return origin(arr) end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index c7146bab65..dded8333f2 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -48,9 +48,25 @@ end @mtkbuild sys = ODESystem( [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) @test length(equations(sys)) == 1 - @test length(observed(sys)) == 6 + @test length(observed(sys)) == 7 @test any(eq -> isequal(eq.lhs, y), observed(sys)) @test any(eq -> isequal(eq.lhs, z), observed(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) @test_nowarn prob.f(prob.u0, prob.p, 0.0) end + +@testset "scalarized array observed calling same function multiple times" begin + @variables x(t) y(t)[1:2] + @parameters foo(::Real)[1:2] + val = Ref(0) + function _tmp_fn2(x) + val[] += 1 + return [x, 2x] + end + @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) + @test length(equations(sys)) == 1 + @test length(observed(sys)) == 3 + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 1 +end From c023e7ebba82d5d83b7f6448ad04a67f37fccefa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 23 Oct 2024 20:17:18 +0530 Subject: [PATCH 4/7] fix: undo the hack in `generate_initializesystem` --- src/systems/nonlinear/initializesystem.jl | 45 ++++++++++++++++++++--- test/structural_transformation/utils.jl | 16 ++++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index becace8ec5..0bdb770c54 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -12,7 +12,9 @@ function generate_initializesystem(sys::ODESystem; algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), kwargs...) - vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)]) + trueobs = unhack_observed(observed(sys)) + @show trueobs + vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup eqs = equations(sys) @@ -24,7 +26,7 @@ function generate_initializesystem(sys::ODESystem; D = Differential(get_iv(sys)) diffmap = merge( Dict(eq.lhs => eq.rhs for eq in eqs_diff), - Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys)) + Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs) ) # 1) process dummy derivatives and u0map into initialization system @@ -166,15 +168,14 @@ function generate_initializesystem(sys::ODESystem; ) # 7) use observed equations for guesses of observed variables if not provided - obseqs = observed(sys) - for eq in obseqs + for eq in trueobs haskey(defs, eq.lhs) && continue any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue defs[eq.lhs] = eq.rhs end - eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,)) + eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,)) vars = [vars; collect(values(paramsubs))] for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) @@ -324,3 +325,37 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) return nothing, nothing, nothing, nothing end end + +""" +Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with +initialization. +""" +function unhack_observed(eqs::Vector{Equation}) + subs = Dict() + tempvars = Set() + rm_idxs = Int[] + for (i, eq) in enumerate(eqs) + iscall(eq.rhs) || continue + if operation(eq.rhs) == StructuralTransformations.change_origin + push!(rm_idxs, i) + continue + end + if operation(eq.rhs) == StructuralTransformations.getindex_wrapper + var, idxs = arguments(eq.rhs) + subs[eq.rhs] = var[idxs...] + push!(tempvars, var) + end + end + + for (i, eq) in enumerate(eqs) + if eq.lhs in tempvars + subs[eq.lhs] = eq.rhs + push!(rm_idxs, i) + end + end + + eqs = eqs[setdiff(eachindex(eqs), rm_idxs)] + return map(eqs) do eq + fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + end +end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index dded8333f2..ea3552ff0d 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -53,6 +53,14 @@ end @test any(eq -> isequal(eq.lhs, z), observed(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) @test_nowarn prob.f(prob.u0, prob.p, 0.0) + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 5 + @test length(equations(isys)) == 4 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end end @testset "scalarized array observed calling same function multiple times" begin @@ -69,4 +77,12 @@ end prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) @test_nowarn prob.f(prob.u0, prob.p, 0.0) @test val[] == 1 + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 3 + @test length(equations(isys)) == 2 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end end From 79a7fc933b858e51412abc486ce1018fb8484638 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 24 Oct 2024 04:51:51 +0000 Subject: [PATCH 5/7] Update src/systems/nonlinear/initializesystem.jl --- src/systems/nonlinear/initializesystem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 0bdb770c54..6c7457e49b 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -13,7 +13,6 @@ function generate_initializesystem(sys::ODESystem; check_units = true, check_defguess = false, name = nameof(sys), kwargs...) trueobs = unhack_observed(observed(sys)) - @show trueobs vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup From d2fe0eb1a482ec7dc28b39bc21c7367bc1b60f8e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 11:43:38 +0530 Subject: [PATCH 6/7] feat: allow CSE and array hacks to be disabled --- .../symbolics_tearing.jl | 79 +++++++++++-------- test/structural_transformation/utils.jl | 44 +++++++++++ 2 files changed, 90 insertions(+), 33 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index d7bec888b5..cabf0415ea 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -230,7 +230,7 @@ end =# function tearing_reassemble(state::TearingState, var_eq_matching, - full_var_eq_matching = nothing; simplify = false, mm = nothing) + full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true) @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure extra_vars = Int[] @@ -584,24 +584,48 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end @set! sys.unknowns = unknowns - # HACK: Since we don't support array equations, any equation of the sort - # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly - # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically - # for this case to avoid the unnecessary cost. - # This and the below hack are implemented simultaneously + obs, subeqs = cse_and_array_hacks( + obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack) + @set! sys.eqs = neweqs + @set! sys.observed = obs + + @set! sys.substitutions = Substitutions(subeqs, deps) + + # Only makes sense for time-dependent + # TODO: generalize to SDE + if sys isa ODESystem + @set! sys.schedule = Schedule(var_eq_matching, dummy_sub) + end + sys = schedule(sys) + @set! state.sys = sys + @set! sys.tearing_state = state + return invalidate_cache!(sys) +end + +""" +# HACK 1 + +Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]` +gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets +_very_ expensive. this hack performs a limited form of CSE specifically for this case to +avoid the unnecessary cost. This and the below hack are implemented simultaneously + +# HACK 2 + +Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an +equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation +if all `p[i]` are present and the unscalarized form is used in any equation (observed or +not) we first count the number of times the scalarized form of each observed variable +occurs in observed equations (and unknowns if it's split). +""" +function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true) + # HACK 1 # mapping of rhs to temporary CSE variable # `f(...) => tmpvar` in above example rhs_to_tempvar = Dict() - # HACK: Add equations for array observed variables. If `p[i] ~ (...)` - # are equations, add an equation `p ~ [p[1], p[2], ...]` - # allow topsort to reorder them - # only add the new equation if all `p[i]` are present and the unscalarized - # form is used in any equation (observed or not) - # we first count the number of times the scalarized form of each observed - # variable occurs in observed equations (and unknowns if it's split). - + # HACK 2 # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations arr_obs_occurrences = Dict() @@ -613,7 +637,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching, vars!(all_vars, rhs) # HACK 1 - if (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && + if cse && + (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && iscall(rhs) && operation(rhs) === getindex && Symbolics.shape(rhs) != Symbolics.Unknown() rhs_arr = arguments(rhs)[1] @@ -643,6 +668,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end # end HACK 1 + array || continue iscall(lhs) || continue operation(lhs) === getindex || continue Symbolics.shape(lhs) != Symbolics.Unknown() || continue @@ -687,20 +713,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, # need to re-sort subeqs subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) - @set! sys.eqs = neweqs - @set! sys.observed = obs - - @set! sys.substitutions = Substitutions(subeqs, deps) - - # Only makes sense for time-dependent - # TODO: generalize to SDE - if sys isa ODESystem - @set! sys.schedule = Schedule(var_eq_matching, dummy_sub) - end - sys = schedule(sys) - @set! state.sys = sys - @set! sys.tearing_state = state - return invalidate_cache!(sys) + return obs, subeqs end # PART OF HACK 1 @@ -733,10 +746,10 @@ new residual equations after tearing. End users are encouraged to call [`structu instead, which calls this function internally. """ function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing, - simplify = false, kwargs...) + simplify = false, cse_hack = true, array_hack = true, kwargs...) var_eq_matching, full_var_eq_matching = tearing(state) invalidate_cache!(tearing_reassemble( - state, var_eq_matching, full_var_eq_matching; mm, simplify)) + state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack)) end """ @@ -758,7 +771,7 @@ Perform index reduction and use the dummy derivative technique to ensure that the system is balanced. """ function dummy_derivative(sys, state = TearingState(sys); simplify = false, - mm = nothing, kwargs...) + mm = nothing, cse_hack = true, array_hack = true, kwargs...) jac = let state = state (eqs, vars) -> begin symeqs = EquationsView(state)[eqs] @@ -782,5 +795,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, end var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...) - tearing_reassemble(state, var_eq_matching; simplify, mm) + tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack) end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index ea3552ff0d..04600a7a6b 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -86,3 +86,47 @@ end StructuralTransformations.change_origin] end end + +@testset "array and cse hacks can be disabled" begin + @testset "fully_determined = true" begin + @variables x(t) y(t)[1:2] z(t)[1:2] + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @named sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + + sys1 = structural_simplify(sys; cse_hack = false) + @test length(observed(sys1)) == 6 + @test !any(observed(sys1)) do eq + iscall(eq.rhs) && + operation(eq.rhs) == StructuralTransformations.getindex_wrapper + end + + sys2 = structural_simplify(sys; array_hack = false) + @test length(observed(sys2)) == 5 + @test !any(observed(sys2)) do eq + iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin + end + end + + @testset "fully_determined = false" begin + @variables x(t) y(t)[1:2] z(t)[1:2] w(t) + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @named sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + + sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false) + @test length(observed(sys1)) == 6 + @test !any(observed(sys1)) do eq + iscall(eq.rhs) && + operation(eq.rhs) == StructuralTransformations.getindex_wrapper + end + + sys2 = structural_simplify(sys; array_hack = false, fully_determined = false) + @test length(observed(sys2)) == 5 + @test !any(observed(sys2)) do eq + iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin + end + end +end From 10253638d71ad9101f113b0dcaf1828e43615541 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 20:16:14 +0530 Subject: [PATCH 7/7] feat: extend CSE hack to non-observed equations --- .../symbolics_tearing.jl | 45 ++++++++++++++++--- src/systems/nonlinear/initializesystem.jl | 24 +++++++--- test/structural_transformation/utils.jl | 22 +++++++++ 3 files changed, 79 insertions(+), 12 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index cabf0415ea..a854acb9b1 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -584,7 +584,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end @set! sys.unknowns = unknowns - obs, subeqs = cse_and_array_hacks( + obs, subeqs, deps = cse_and_array_hacks( obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack) @set! sys.eqs = neweqs @@ -637,10 +637,7 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = vars!(all_vars, rhs) # HACK 1 - if cse && - (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && - iscall(rhs) && operation(rhs) === getindex && - Symbolics.shape(rhs) != Symbolics.Unknown() + if cse && is_getindexed_array(rhs) rhs_arr = arguments(rhs)[1] if !haskey(rhs_to_tempvar, rhs_arr) tempvar = gensym(Symbol(lhs)) @@ -677,6 +674,33 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = arr_obs_occurrences[arg1] = cnt + 1 continue end + + # Also do CSE for `equations(sys)` + if cse + for (i, eq) in enumerate(neweqs) + (; lhs, rhs) = eq + is_getindexed_array(rhs) || continue + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + push!(obs, tempeq) + push!(subeqs, tempeq) + end + # don't need getindex_wrapper, but do it anyway to know that this + # hack took place + neweq = lhs ~ getindex_wrapper( + rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + neweqs[i] = neweq + end + end + # count variables in unknowns if they are scalarized forms of variables # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` # is an observed equation. @@ -713,7 +737,16 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = # need to re-sort subeqs subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) - return obs, subeqs + deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) + for i in 1:length(subeqs)] + + return obs, subeqs, deps +end + +function is_getindexed_array(rhs) + (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && + iscall(rhs) && operation(rhs) === getindex && + Symbolics.shape(rhs) != Symbolics.Unknown() end # PART OF HACK 1 diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 6c7457e49b..eefe393acc 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -12,11 +12,10 @@ function generate_initializesystem(sys::ODESystem; algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), kwargs...) - trueobs = unhack_observed(observed(sys)) + trueobs, eqs = unhack_observed(observed(sys), equations(sys)) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup - eqs = equations(sys) idxs_diff = isdiffeq.(eqs) idxs_alge = .!idxs_diff @@ -329,11 +328,11 @@ end Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with initialization. """ -function unhack_observed(eqs::Vector{Equation}) +function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) subs = Dict() tempvars = Set() rm_idxs = Int[] - for (i, eq) in enumerate(eqs) + for (i, eq) in enumerate(obseqs) iscall(eq.rhs) || continue if operation(eq.rhs) == StructuralTransformations.change_origin push!(rm_idxs, i) @@ -347,14 +346,27 @@ function unhack_observed(eqs::Vector{Equation}) end for (i, eq) in enumerate(eqs) + iscall(eq.rhs) || continue + if operation(eq.rhs) == StructuralTransformations.getindex_wrapper + var, idxs = arguments(eq.rhs) + subs[eq.rhs] = var[idxs...] + push!(tempvars, var) + end + end + + for (i, eq) in enumerate(obseqs) if eq.lhs in tempvars subs[eq.lhs] = eq.rhs push!(rm_idxs, i) end end - eqs = eqs[setdiff(eachindex(eqs), rm_idxs)] - return map(eqs) do eq + obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] + obseqs = map(obseqs) do eq + fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + end + eqs = map(eqs) do eq fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) end + return obseqs, eqs end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 04600a7a6b..2704559f72 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -85,6 +85,28 @@ end iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, StructuralTransformations.change_origin] end + + @testset "CSE hack in equations(sys)" begin + val[] = 0 + @variables z(t)[1:2] + @mtkbuild sys = ODESystem( + [D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t) + @test length(equations(sys)) == 5 + @test length(observed(sys)) == 2 + prob = ODEProblem( + sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 2 + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 5 + @test length(equations(isys)) == 2 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && + operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end + end end @testset "array and cse hacks can be disabled" begin