diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 7462122998..d88120afa0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -54,7 +54,8 @@ import SCCNonlinearSolve using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix -import BlockArrays: BlockedArray, Block, blocksize, blocksizes +import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!, + undef_blocks, blocks import CommonSolve import EnumX diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ccc749baeb..2bafe55396 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1919,7 +1919,7 @@ function Base.show( nrows > 0 && hint && print(io, " see hierarchy($name)") for i in 1:nrows sub = subs[i] - name = String(nameof(sub)) + local name = String(nameof(sub)) print(io, "\n ", name) desc = description(sub) if !isempty(desc) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index fe18d0de35..6b0b0cc759 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...) end +const TypeT = Union{DataType, UnionAll} + struct CacheWriter{F} fn::F end function (cw::CacheWriter)(p, sols) - cw.fn(p.caches[1], sols, p...) + cw.fn(p.caches, sols, p...) end -function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation}; +function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT}, + exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation}; eval_expression = false, eval_module = @__MODULE__) ps = parameters(sys) rps = reorder_parameters(sys, ps) obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs] cmap, cs = get_cmap(sys) cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap] + + outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)] + body = map(eachindex(buffer_types), buffer_types) do i, T + Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, [])) + end fn = Func( [:out, DestructuredArgs(DestructuredArgs.(solsyms)), DestructuredArgs.(rps)...], [], - SetArray(true, :out, exprs) + Let(body, :()) ) |> wrap_assignments(false, obs_assigns)[2] |> wrap_parameter_dependencies(sys, false)[2] |> - wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> + wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |> wrap_assignments(false, cmap_assigns)[2] |> toexpr return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module)) end @@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, explicitfuns = [] nlfuns = [] - prevobsidxs = Int[] - cachesize = 0 + prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[]) + # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a + # dict to maintain a consistent order of buffers across SCCs + cachetypes = TypeT[] + cachesizes = Int[] + # explicitfun! related information for each SCC + # We need to compute buffer sizes before doing any codegen + scc_cachevars = Dict{TypeT, Vector{Any}}[] + scc_cacheexprs = Dict{TypeT, Vector{Any}}[] + scc_eqs = Vector{Equation}[] + scc_obs = Vector{Equation}[] for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs)) # subset unknowns and equations _dvs = dvs[vscc] @@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, _obs = obs[obsidxs] # get all subexpressions in the RHS which we can precompute in the cache + # precomputed subexpressions should not contain `banned_vars` banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,)))) - for var in banned_vars - iscall(var) || continue - operation(var) === getindex || continue - push!(banned_vars, arguments(var)[1]) + filter!(banned_vars) do var + symbolic_type(var) != ArraySymbolic() || all(x -> var[i] in banned_vars, eachindex(var)) end state = Dict() for i in eachindex(_obs) @@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, _eqs[i].rhs, banned_vars, state) end - # cached variables and their corresponding expressions - cachevars = Any[obs[i].lhs for i in prevobsidxs] - cacheexprs = Any[obs[i].lhs for i in prevobsidxs] + # map from symtype to cached variables and their expressions + cachevars = Dict{Union{DataType, UnionAll}, Vector{Any}}() + cacheexprs = Dict{Union{DataType, UnionAll}, Vector{Any}}() + # observed of previous SCCs are in the cache + # NOTE: When we get proper CSE, we can substitute these + # and then use `subexpressions_not_involving_vars!` + for i in prevobsidxs + T = symtype(obs[i].lhs) + buf = get!(() -> Any[], cachevars, T) + push!(buf, obs[i].lhs) + + buf = get!(() -> Any[], cacheexprs, T) + push!(buf, obs[i].lhs) + end + for (k, v) in state - push!(cachevars, unwrap(v)) - push!(cacheexprs, unwrap(k)) + k = unwrap(k) + v = unwrap(v) + T = symtype(k) + buf = get!(() -> Any[], cachevars, T) + push!(buf, v) + buf = get!(() -> Any[], cacheexprs, T) + push!(buf, k) end - cachesize = max(cachesize, length(cachevars)) + + # update the sizes of cache buffers + for (T, buf) in cachevars + idx = findfirst(isequal(T), cachetypes) + if idx === nothing + push!(cachetypes, T) + push!(cachesizes, 0) + idx = lastindex(cachetypes) + end + cachesizes[idx] = max(cachesizes[idx], length(buf)) + end + + push!(scc_cachevars, cachevars) + push!(scc_cacheexprs, cacheexprs) + push!(scc_eqs, _eqs) + push!(scc_obs, _obs) + blockpush!(prevobsidxs, obsidxs) + end + + for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs)) + _dvs = dvs[vscc] + _eqs = scc_eqs[i] + _prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[]) + _obs = scc_obs[i] + cachevars = scc_cachevars[i] + cacheexprs = scc_cacheexprs[i] if isempty(cachevars) push!(explicitfuns, Returns(nothing)) else solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1))) push!(explicitfuns, - CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs]; + CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs]; eval_expression, eval_module)) end + + cachebufsyms = Tuple(map(cachetypes) do T + get(cachevars, T, []) + end) f = SCCNonlinearFunction{iip}( - sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...) + sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...) push!(nlfuns, f) - append!(cachevars, _dvs) - append!(cacheexprs, _dvs) - for i in obsidxs - push!(cachevars, obs[i].lhs) - push!(cacheexprs, obs[i].rhs) - end - append!(prevobsidxs, obsidxs) end - if cachesize != 0 - p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize)) + if !isempty(cachetypes) + templates = map(cachetypes, cachesizes) do T, n + # Real refers to `eltype(u0)` + if T == Real + T = eltype(u0) + elseif T <: Array && eltype(T) == Real + T = Array{eltype(u0), ndims(T)} + end + BufferTemplate(T, n) + end + p = rebuild_with_caches(p, templates...) end subprobs = [] diff --git a/src/utils.jl b/src/utils.jl index c3011c2a79..5e4f0b52d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1108,23 +1108,33 @@ returns the modified `expr`. """ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) expr = unwrap(expr) - symbolic_type(expr) == NotSymbolic() && return expr + if symbolic_type(expr) == NotSymbolic() + if is_array_of_symbolics(expr) + return map(expr) do el + subexpressions_not_involving_vars!(el, vars, state) + end + end + return expr + end + any(isequal(expr), vars) && return expr iscall(expr) || return expr - is_variable_floatingpoint(expr) || return expr - symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr Symbolics.shape(expr) == Symbolics.Unknown() && return expr haskey(state, expr) && return state[expr] - vs = ModelingToolkit.vars(expr) - intersect!(vs, vars) - if isempty(vs) + op = operation(expr) + args = arguments(expr) + # if this is a `getindex` and the getindex-ed value is a `Sym` + # or it is not a called parameter + # OR + # none of `vars` are involved in `expr` + if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) || + (vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs)) sym = gensym(:subexpr) stype = symtype(expr) var = similar_variable(expr, sym) state[expr] = var return var end - op = operation(expr) - args = arguments(expr) + if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic() indep_args = [] dep_args = [] @@ -1143,7 +1153,6 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) return op(indep_term, dep_term) end newargs = map(args) do arg - symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg subexpressions_not_involving_vars!(arg, vars, state) end return maketerm(typeof(expr), op, newargs, metadata(expr)) diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl index fdf1646343..57f3d72fb7 100644 --- a/test/scc_nonlinear_problem.jl +++ b/test/scc_nonlinear_problem.jl @@ -161,3 +161,95 @@ end @test SciMLBase.successful_retcode(sccsol) @test val[] == 1 end + +import ModelingToolkitStandardLibrary.Blocks as B +import ModelingToolkitStandardLibrary.Mechanical.Translational as T +import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC + +@testset "Caching of subexpressions of different types" begin + liquid_pressure(rho, rho_0, bulk) = (rho / rho_0 - 1) * bulk + gas_pressure(rho, rho_0, p_gas, rho_gas) = rho * ((0 - p_gas) / (rho_0 - rho_gas)) + full_pressure(rho, rho_0, bulk, p_gas, rho_gas) = ifelse( + rho >= rho_0, liquid_pressure(rho, rho_0, bulk), + gas_pressure(rho, rho_0, p_gas, rho_gas)) + + @component function Volume(; + #parameters + area, + direction = +1, + x_int, + name) + pars = @parameters begin + area = area + x_int = x_int + rho_0 = 1000 + bulk = 1e9 + p_gas = -1000 + rho_gas = 1 + end + + vars = @variables begin + x(t) = x_int + dx(t), [guess = 0] + p(t), [guess = 0] + f(t), [guess = 0] + rho(t), [guess = 0] + m(t), [guess = 0] + dm(t), [guess = 0] + end + + systems = @named begin + port = IC.HydraulicPort() + flange = T.MechanicalPort() + end + + eqs = [ + # connectors + port.p ~ p + port.dm ~ dm + flange.v * direction ~ dx + flange.f * direction ~ -f + + # differentials + D(x) ~ dx + D(m) ~ dm + + # physics + p ~ full_pressure(rho, rho_0, bulk, p_gas, rho_gas) + f ~ p * area + m ~ rho * x * area] + + return ODESystem(eqs, t, vars, pars; name, systems) + end + + systems = @named begin + fluid = IC.HydraulicFluid(; bulk_modulus = 1e9) + + src1 = IC.Pressure(;) + src2 = IC.Pressure(;) + + vol1 = Volume(; area = 0.01, direction = +1, x_int = 0.1) + vol2 = Volume(; area = 0.01, direction = +1, x_int = 0.1) + + mass = T.Mass(; m = 10) + + sin1 = B.Sine(; frequency = 0.5, amplitude = +0.5e5, offset = 10e5) + sin2 = B.Sine(; frequency = 0.5, amplitude = -0.5e5, offset = 10e5) + end + + eqs = [connect(fluid, src1.port) + connect(fluid, src2.port) + connect(src1.port, vol1.port) + connect(src2.port, vol2.port) + connect(vol1.flange, mass.flange, vol2.flange) + connect(src1.p, sin1.output) + connect(src2.p, sin2.output)] + + initialization_eqs = [mass.s ~ 0.0 + mass.v ~ 0.0] + + @mtkbuild sys = ODESystem(eqs, t, [], []; systems, initialization_eqs) + prob = ODEProblem(sys, [], (0, 5)) + sol = solve(prob) + @test SciMLBase.successful_retcode(sol) +end