Skip to content

Commit 80656a0

Browse files
committed
add optional caching to rendering
improves performance from quadratic scaling to linear (but with a worse factor)
1 parent 42825b6 commit 80656a0

File tree

1 file changed

+82
-1
lines changed

1 file changed

+82
-1
lines changed

ext/Render.jl

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ function recursive_extract(sol, s, depth=0)
4040
end
4141
end
4242

43-
4443
function getu(sol::FakeSol, syms)
4544
t->[recursive_extract(sol, s) for s in syms]
4645
end
@@ -75,6 +74,81 @@ function Base.getproperty(sol::FakeSol, s::Symbol)
7574
end
7675
end
7776

77+
mutable struct CacheSol
78+
model
79+
sol
80+
cache::Dict{Any, Any}
81+
vars
82+
last_t::Float64
83+
function CacheSol(model, sol)
84+
vars = get_all_vars(model) |> unique
85+
Main.vars = vars
86+
values = sol(0.0, idxs=vars)
87+
new(model, sol, Dict(vars .=> values), vars, 0)
88+
end
89+
end
90+
91+
function get_all_vars(model, vars = Multibody.collect_all(unknowns(model)))
92+
for sys in model.systems
93+
if ModelingToolkit.isframe(sys)
94+
newvars = Multibody.ModelingToolkit.renamespace.(model.name, Multibody.Symbolics.unwrap.(vec(ori(sys).R)))
95+
append!(vars, newvars)
96+
else
97+
subsys_ns = getproperty(model, sys.name)
98+
get_all_vars(subsys_ns, vars)
99+
end
100+
end
101+
vars
102+
end
103+
104+
105+
function get_cached(cs::CacheSol, t, idxs)
106+
if ModelingToolkit.isparameter(idxs[1])
107+
return cs.prob.ps[idxs]
108+
end
109+
if idxs isa AbstractArray{Num}
110+
idxs = Multibody.Symbolics.unwrap.(idxs)
111+
end
112+
if !haskey(cs.cache, idxs[1])
113+
# Fallback for things not in cache
114+
return cs.sol(t; idxs)
115+
end
116+
if t != cs.last_t
117+
values = cs.sol(t, idxs=cs.vars)
118+
cs.cache = Dict(cs.vars .=> values)
119+
cs.last_t = t
120+
end
121+
if idxs isa Real
122+
return cs.cache[idxs]
123+
else
124+
return [cs.cache[i] for i in idxs]
125+
end
126+
end
127+
128+
129+
function getu(cs::CacheSol, syms)
130+
t->get_cached(cs::CacheSol, t.t, syms)
131+
end
132+
133+
function ModelingToolkit.parameter_values(cs::CacheSol)
134+
ModelingToolkit.parameter_values(cs.sol)
135+
# pars = Multibody.collect_all(parameters(cs.model))
136+
# cs.prob.ps[pars]
137+
end
138+
139+
function (cs::CacheSol)(t; idxs=nothing)
140+
if idxs === nothing
141+
cs.sol(t)
142+
else
143+
get_cached(cs, t, idxs)
144+
end
145+
end
146+
147+
function Base.getproperty(cs::CacheSol, s::Symbol)
148+
s fieldnames(typeof(cs)) && return getfield(cs, s)
149+
return getproperty(getfield(cs, :sol), s)
150+
end
151+
78152

79153

80154
"""
@@ -199,11 +273,14 @@ function render(model, sol,
199273
traces = nothing,
200274
display = false,
201275
loop = 1,
276+
cache = true,
202277
kwargs...
203278
)
204279
if sol isa ODEProblem
205280
sol = FakeSol(model, sol)
206281
return render(model, sol, 0; x, y, z, lookat, up, show_axis, kwargs...)[1]
282+
elseif cache
283+
sol = CacheSol(model, sol)
207284
end
208285
scene, fig = default_scene(x,y,z; lookat,up,show_axis)
209286
if timevec === nothing
@@ -258,6 +335,7 @@ function render(model, sol, time::Real;
258335
x = 2,
259336
y = 0.5,
260337
z = 2,
338+
cache = true,
261339
kwargs...,
262340
)
263341

@@ -266,6 +344,9 @@ function render(model, sol, time::Real;
266344
if sol isa ODEProblem
267345
sol = FakeSol(model, sol)
268346
end
347+
if cache
348+
sol = CacheSol(model, sol)
349+
end
269350

270351
# fig = Figure()
271352
# scene = LScene(fig[1, 1]).scene

0 commit comments

Comments
 (0)