Skip to content

Commit 2467c69

Browse files
feat: cache subexpressions dependent only on previous SCCs
1 parent 9a3b275 commit 2467c69

File tree

3 files changed

+140
-6
lines changed

3 files changed

+140
-6
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,11 @@ function SCCNonlinearFunction{iip}(
579579
f(resid, u, p) = f_iip(resid, u, p)
580580
f(resid, u, p::MTKParameters) = f_iip(resid, u, p...)
581581

582-
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
582+
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs,
583+
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
583584
if get_index_cache(sys) !== nothing
584-
@set! subsys.index_cache = subset_unknowns_observed(get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
585+
@set! subsys.index_cache = subset_unknowns_observed(
586+
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
585587
@set! subsys.complete = true
586588
end
587589

@@ -620,8 +622,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
620622
explicitfuns = []
621623
nlfuns = []
622624
prevobsidxs = Int[]
623-
cachevars = []
624-
cacheexprs = []
625+
cachesize = 0
625626
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
626627
# subset unknowns and equations
627628
_dvs = dvs[vscc]
@@ -632,6 +633,32 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
632633
setdiff!(obsidxs, prevobsidxs)
633634
_obs = obs[obsidxs]
634635

636+
# get all subexpressions in the RHS which we can precompute in the cache
637+
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
638+
for var in banned_vars
639+
iscall(var) || continue
640+
operation(var) === getindex || continue
641+
push!(banned_vars, arguments(var)[1])
642+
end
643+
state = Dict()
644+
for i in eachindex(_obs)
645+
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
646+
_obs[i].rhs, banned_vars, state)
647+
end
648+
for i in eachindex(_eqs)
649+
_eqs[i] = _eqs[i].lhs ~ subexpressions_not_involving_vars!(
650+
_eqs[i].rhs, banned_vars, state)
651+
end
652+
653+
# cached variables and their corresponding expressions
654+
cachevars = Any[obs[i].lhs for i in prevobsidxs]
655+
cacheexprs = Any[obs[i].rhs for i in prevobsidxs]
656+
for (k, v) in state
657+
push!(cachevars, unwrap(v))
658+
push!(cacheexprs, unwrap(k))
659+
end
660+
cachesize = max(cachesize, length(cachevars))
661+
635662
if isempty(cachevars)
636663
push!(explicitfuns, Returns(nothing))
637664
else
@@ -651,7 +678,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
651678
append!(prevobsidxs, obsidxs)
652679
end
653680

654-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars)))
681+
if cachesize != 0
682+
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
683+
end
655684

656685
subprobs = []
657686
for (f, vscc) in zip(nlfuns, var_sccs)

src/utils.jl

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
569569
return nothing
570570
end
571571

572-
573572
function collect_var!(unknowns, parameters, var, iv; depth = 0)
574573
isequal(var, iv) && return nothing
575574
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
@@ -1002,6 +1001,11 @@ end
10021001
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
10031002
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
10041003

1004+
"""
1005+
$(TYPEDSIGNATURES)
1006+
1007+
Check if `sym` represents a symbolic floating point number or array of such numbers.
1008+
"""
10051009
function is_variable_floatingpoint(sym)
10061010
sym = unwrap(sym)
10071011
T = symtype(sym)
@@ -1053,3 +1057,87 @@ function observed_equations_used_by(sys::AbstractSystem, exprs)
10531057
sort!(obsidxs)
10541058
return obsidxs
10551059
end
1060+
1061+
"""
1062+
$(TYPEDSIGNATURES)
1063+
1064+
Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do
1065+
not involve variables in `vars` to anonymous symbolic variables. Also return the modified
1066+
`expr` with the substitutions indicated by the dictionary. If `expr` is a function
1067+
of only `vars`, then all of the returned subexpressions can be precomputed.
1068+
1069+
Note that this will only process subexpressions floating point value. Additionally,
1070+
array variables must be passed in both scalarized and non-scalarized forms in `vars`.
1071+
"""
1072+
function subexpressions_not_involving_vars(expr, vars)
1073+
expr = unwrap(expr)
1074+
vars = map(unwrap, vars)
1075+
state = Dict()
1076+
newexpr = subexpressions_not_involving_vars!(expr, vars, state)
1077+
return state, newexpr
1078+
end
1079+
1080+
"""
1081+
$(TYPEDSIGNATURES)
1082+
1083+
Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only
1084+
returns the modified `expr`.
1085+
"""
1086+
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
1087+
expr = unwrap(expr)
1088+
symbolic_type(expr) == NotSymbolic() && return expr
1089+
iscall(expr) || return expr
1090+
is_variable_floatingpoint(expr) || return expr
1091+
symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr
1092+
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
1093+
haskey(state, expr) && return state[expr]
1094+
vs = ModelingToolkit.vars(expr)
1095+
intersect!(vs, vars)
1096+
if isempty(vs)
1097+
sym = gensym(:subexpr)
1098+
stype = symtype(expr)
1099+
var = similar_variable(expr, sym)
1100+
state[expr] = var
1101+
return var
1102+
end
1103+
op = operation(expr)
1104+
args = arguments(expr)
1105+
if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
1106+
indep_args = []
1107+
dep_args = []
1108+
for arg in args
1109+
_vs = ModelingToolkit.vars(arg)
1110+
intersect!(_vs, vars)
1111+
if !isempty(_vs)
1112+
push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state))
1113+
else
1114+
push!(indep_args, arg)
1115+
end
1116+
end
1117+
indep_term = reduce(op, indep_args; init = Int(op == (*)))
1118+
indep_term = subexpressions_not_involving_vars!(indep_term, vars, state)
1119+
dep_term = reduce(op, dep_args; init = Int(op == (*)))
1120+
return op(indep_term, dep_term)
1121+
end
1122+
newargs = map(args) do arg
1123+
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
1124+
subexpressions_not_involving_vars!(arg, vars, state)
1125+
end
1126+
return maketerm(typeof(expr), op, newargs, metadata(expr))
1127+
end
1128+
1129+
"""
1130+
$(TYPEDSIGNATURES)
1131+
1132+
Create an anonymous symbolic variable of the same shape, size and symtype as `var`, with
1133+
name `gensym(name)`. Does not support unsized array symbolics.
1134+
"""
1135+
function similar_variable(var::BasicSymbolic, name = :anon)
1136+
name = gensym(name)
1137+
stype = symtype(var)
1138+
sym = Symbolics.variable(name; T = stype)
1139+
if size(var) !== ()
1140+
sym = setmetadata(sym, Symbolics.ArrayShapeCtx, map(Base.OneTo, size(var)))
1141+
end
1142+
return sym
1143+
end

test/scc_nonlinear_problem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,20 @@ end
143143
@test sol.usccsol.u atol=1e-10
144144
end
145145

146+
@testset "Expression caching" begin
147+
@variables x[1:4] = rand(4)
148+
val = Ref(0)
149+
function func(x, y)
150+
val[] += 1
151+
x + y
152+
end
153+
@register_symbolic func(x, y)
154+
@mtkbuild sys = NonlinearSystem([0 ~ x[1]^3 + x[2]^3 - 5
155+
0 ~ sin(x[1] - x[2]) - 0.5
156+
0 ~ func(x[1], x[2]) * exp(x[3]) - x[4]^3 - 5
157+
0 ~ func(x[1], x[2]) * exp(x[4]) - x[3]^3 - 4])
158+
sccprob = SCCNonlinearProblem(sys, [])
159+
sccsol = solve(sccprob, NewtonRaphson())
160+
@test SciMLBase.successful_retcode(sccsol)
161+
@test val[] == 1
162+
end

0 commit comments

Comments
 (0)