@@ -40,7 +40,6 @@ function recursive_extract(sol, s, depth=0)
40
40
end
41
41
end
42
42
43
-
44
43
function getu (sol:: FakeSol , syms)
45
44
t-> [recursive_extract (sol, s) for s in syms]
46
45
end
@@ -75,6 +74,81 @@ function Base.getproperty(sol::FakeSol, s::Symbol)
75
74
end
76
75
end
77
76
77
+ mutable struct CacheSol
78
+ model
79
+ sol
80
+ cache:: Dict{Any, Any}
81
+ vars
82
+ last_t:: Float64
83
+ function CacheSol (model, sol)
84
+ vars = get_all_vars (model) |> unique
85
+ Main. vars = vars
86
+ values = sol (0.0 , idxs= vars)
87
+ new (model, sol, Dict (vars .=> values), vars, 0 )
88
+ end
89
+ end
90
+
91
+ function get_all_vars (model, vars = Multibody. collect_all (unknowns (model)))
92
+ for sys in model. systems
93
+ if ModelingToolkit. isframe (sys)
94
+ newvars = Multibody. ModelingToolkit. renamespace .(model. name, Multibody. Symbolics. unwrap .(vec (ori (sys). R)))
95
+ append! (vars, newvars)
96
+ else
97
+ subsys_ns = getproperty (model, sys. name)
98
+ get_all_vars (subsys_ns, vars)
99
+ end
100
+ end
101
+ vars
102
+ end
103
+
104
+
105
+ function get_cached (cs:: CacheSol , t, idxs)
106
+ if ModelingToolkit. isparameter (idxs[1 ])
107
+ return cs. prob. ps[idxs]
108
+ end
109
+ if idxs isa AbstractArray{Num}
110
+ idxs = Multibody. Symbolics. unwrap .(idxs)
111
+ end
112
+ if ! haskey (cs. cache, idxs[1 ])
113
+ # Fallback for things not in cache
114
+ return cs. sol (t; idxs)
115
+ end
116
+ if t != cs. last_t
117
+ values = cs. sol (t, idxs= cs. vars)
118
+ cs. cache = Dict (cs. vars .=> values)
119
+ cs. last_t = t
120
+ end
121
+ if idxs isa Real
122
+ return cs. cache[idxs]
123
+ else
124
+ return [cs. cache[i] for i in idxs]
125
+ end
126
+ end
127
+
128
+
129
+ function getu (cs:: CacheSol , syms)
130
+ t-> get_cached (cs:: CacheSol , t. t, syms)
131
+ end
132
+
133
+ function ModelingToolkit. parameter_values (cs:: CacheSol )
134
+ ModelingToolkit. parameter_values (cs. sol)
135
+ # pars = Multibody.collect_all(parameters(cs.model))
136
+ # cs.prob.ps[pars]
137
+ end
138
+
139
+ function (cs:: CacheSol )(t; idxs= nothing )
140
+ if idxs === nothing
141
+ cs. sol (t)
142
+ else
143
+ get_cached (cs, t, idxs)
144
+ end
145
+ end
146
+
147
+ function Base. getproperty (cs:: CacheSol , s:: Symbol )
148
+ s ∈ fieldnames (typeof (cs)) && return getfield (cs, s)
149
+ return getproperty (getfield (cs, :sol ), s)
150
+ end
151
+
78
152
79
153
80
154
"""
@@ -199,11 +273,14 @@ function render(model, sol,
199
273
traces = nothing ,
200
274
display = false ,
201
275
loop = 1 ,
276
+ cache = true ,
202
277
kwargs...
203
278
)
204
279
if sol isa ODEProblem
205
280
sol = FakeSol (model, sol)
206
281
return render (model, sol, 0 ; x, y, z, lookat, up, show_axis, kwargs... )[1 ]
282
+ elseif cache
283
+ sol = CacheSol (model, sol)
207
284
end
208
285
scene, fig = default_scene (x,y,z; lookat,up,show_axis)
209
286
if timevec === nothing
@@ -258,6 +335,7 @@ function render(model, sol, time::Real;
258
335
x = 2 ,
259
336
y = 0.5 ,
260
337
z = 2 ,
338
+ cache = true ,
261
339
kwargs... ,
262
340
)
263
341
@@ -266,6 +344,9 @@ function render(model, sol, time::Real;
266
344
if sol isa ODEProblem
267
345
sol = FakeSol (model, sol)
268
346
end
347
+ if cache
348
+ sol = CacheSol (model, sol)
349
+ end
269
350
270
351
# fig = Figure()
271
352
# scene = LScene(fig[1, 1]).scene
0 commit comments