Skip to content

Commit 46a6b60

Browse files
feat: cache subexpressions dependent only on previous SCCs
1 parent 1f7851c commit 46a6b60

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,11 @@ function SCCNonlinearFunction{iip}(
579579
f(resid, u, p) = f_iip(resid, u, p)
580580
f(resid, u, p::MTKParameters) = f_iip(resid, u, p...)
581581

582-
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
582+
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs,
583+
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
583584
if get_index_cache(sys) !== nothing
584-
@set! subsys.index_cache = subset_unknowns_observed(get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
585+
@set! subsys.index_cache = subset_unknowns_observed(
586+
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
585587
@set! subsys.complete = true
586588
end
587589

@@ -620,8 +622,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
620622
explicitfuns = []
621623
nlfuns = []
622624
prevobsidxs = Int[]
623-
cachevars = []
624-
cacheexprs = []
625+
cachesize = 0
625626
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
626627
# subset unknowns and equations
627628
_dvs = dvs[vscc]
@@ -632,6 +633,26 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
632633
setdiff!(obsidxs, prevobsidxs)
633634
_obs = obs[obsidxs]
634635

636+
# get all subexpressions in the RHS which we can precompute in the cache
637+
state = Dict()
638+
for i in eachindex(_obs)
639+
_obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!(
640+
_obs[i].rhs, _dvs, state)
641+
end
642+
for i in eachindex(_eqs)
643+
_eqs[i] = _eqs[i].lhs ~ subexpressions_not_involving_vars!(
644+
_eqs[i].rhs, _dvs, state)
645+
end
646+
647+
# cached variables and their corresponding expressions
648+
cachevars = [obs[i].lhs for i in prevobsidxs]
649+
cacheexprs = [obs[i].rhs for i in prevobsidxs]
650+
for (k, v) in state
651+
push!(cachevars, v)
652+
push!(cacheexprs, k)
653+
end
654+
cachesize = max(cachesize, length(cachevars))
655+
635656
if isempty(cachevars)
636657
push!(explicitfuns, Returns(nothing))
637658
else
@@ -651,7 +672,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
651672
append!(prevobsidxs, obsidxs)
652673
end
653674

654-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars)))
675+
if cachesize != 0
676+
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
677+
end
655678

656679
subprobs = []
657680
for (f, vscc) in zip(nlfuns, var_sccs)

src/utils.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,3 +1042,78 @@ function observed_equations_used_by(sys::AbstractSystem, exprs)
10421042
sort!(obsidxs)
10431043
return obsidxs
10441044
end
1045+
1046+
"""
1047+
$(TYPEDSIGNATURES)
1048+
1049+
Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do
1050+
not involve variables in `vars` to anonymous symbolic variables. Also return the modified
1051+
`expr` with the substitutions indicated by the dictionary. If `expr` is a function
1052+
of only `vars`, then all of the returned subexpressions can be precomputed.
1053+
"""
1054+
function subexpressions_not_involving_vars(expr, vars)
1055+
expr = unwrap(expr)
1056+
vars = map(unwrap, vars)
1057+
state = Dict()
1058+
newexpr = subexpressions_not_involving_vars!(expr, vars, state)
1059+
return state, newexpr
1060+
end
1061+
1062+
"""
1063+
$(TYPEDSIGNATURES)
1064+
1065+
Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only
1066+
returns the modified `expr`.
1067+
"""
1068+
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
1069+
expr = unwrap(expr)
1070+
symbolic_type(expr) == NotSymbolic() && return expr
1071+
iscall(expr) || return expr
1072+
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
1073+
haskey(state, expr) && return state[expr]
1074+
if !any(x -> occursin(x, expr), vars)
1075+
sym = gensym(:subexpr)
1076+
stype = symtype(expr)
1077+
var = similar_variable(expr, sym)
1078+
state[expr] = var
1079+
return var
1080+
end
1081+
op = operation(expr)
1082+
args = arguments(expr)
1083+
if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
1084+
indep_args = []
1085+
dep_args = []
1086+
for arg in args
1087+
if any(x -> occursin(x, arg), vars)
1088+
push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state))
1089+
else
1090+
push!(indep_args, arg)
1091+
end
1092+
end
1093+
indep_term = reduce(op, indep_args; init = Int(op == (*)))
1094+
indep_term = subexpressions_not_involving_vars!(indep_term, vars, state)
1095+
dep_term = reduce(op, dep_args; init = Int(op == (*)))
1096+
return op(indep_term, dep_term)
1097+
end
1098+
newargs = map(args) do arg
1099+
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
1100+
subexpressions_not_involving_vars!(arg, vars, state)
1101+
end
1102+
return maketerm(typeof(expr), op, newargs, metadata(expr))
1103+
end
1104+
1105+
"""
1106+
$(TYPEDSIGNATURES)
1107+
1108+
Create an anonymous symbolic variable of the same shape, size and symtype as `var`, with
1109+
name `gensym(name)`. Does not support unsized array symbolics.
1110+
"""
1111+
function similar_variable(var::BasicSymbolic, name = :anon)
1112+
name = gensym(name)
1113+
stype = symtype(var)
1114+
sym = Symbolics.variable(name; T = stype)
1115+
if size(var) !== ()
1116+
sym = setmetadata(sym, Symbolics.ArrayShapeCtx, map(Base.OneTo, size(var)))
1117+
end
1118+
return sym
1119+
end

test/scc_nonlinear_problem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,20 @@ end
142142
@test sol.usccsol.u atol=1e-10
143143
end
144144

145+
@testset "Expression caching" begin
146+
@variables x[1:4] = rand(4)
147+
val = Ref(0)
148+
function func(x, y)
149+
val[] += 1
150+
x + y
151+
end
152+
@register_symbolic func(x, y)
153+
@mtkbuild sys = NonlinearSystem([0 ~ x[1]^3 + x[2]^3 - 5
154+
0 ~ sin(x[1] - x[2]) - 0.5
155+
0 ~ func(x[1], x[2]) * exp(x[3]) - x[4]^3 - 5
156+
0 ~ func(x[1], x[2]) * exp(x[4]) - x[3]^3 - 4])
157+
sccprob = SCCNonlinearProblem(sys, [])
158+
sccsol = solve(sccprob, NewtonRaphson())
159+
@test SciMLBase.successful_retcode(sccsol)
160+
@test val[] == 1
161+
end

0 commit comments

Comments
 (0)