@@ -162,11 +162,12 @@ object.
162162"""
163163function generate_custom_function (sys:: AbstractSystem , exprs, dvs = unknowns (sys),
164164 ps = parameters (sys); wrap_code = nothing , postprocess_fbody = nothing , states = nothing ,
165- expression = Val{true }, eval_expression = false , eval_module = @__MODULE__ , kwargs... )
165+ expression = Val{true }, eval_expression = false , eval_module = @__MODULE__ ,
166+ cachesyms:: Tuple = (), kwargs... )
166167 if ! iscomplete (sys)
167168 error (" A completed system is required. Call `complete` or `structural_simplify` on the system." )
168169 end
169- p = reorder_parameters (sys, unwrap .(ps))
170+ p = ( reorder_parameters (sys, unwrap .(ps)) ... , cachesyms ... )
170171 isscalar = ! (exprs isa AbstractArray)
171172 if wrap_code === nothing
172173 wrap_code = isscalar ? identity : (identity, identity)
@@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
187188 postprocess_fbody,
188189 states,
189190 wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
190- wrap_array_vars (sys, exprs; dvs) .∘
191+ wrap_array_vars (sys, exprs; dvs, cachesyms ) .∘
191192 wrap_parameter_dependencies (sys, isscalar),
192193 expression = Val{true }
193194 )
@@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
199200 postprocess_fbody,
200201 states,
201202 wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
202- wrap_array_vars (sys, exprs; dvs) .∘
203+ wrap_array_vars (sys, exprs; dvs, cachesyms ) .∘
203204 wrap_parameter_dependencies (sys, isscalar),
204205 expression = Val{true }
205206 )
@@ -231,133 +232,59 @@ end
231232
232233function wrap_array_vars (
233234 sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys),
234- inputs = nothing , history = false )
235+ inputs = nothing , history = false , cachesyms :: Tuple = () )
235236 isscalar = ! (exprs isa AbstractArray)
236- array_vars = Dict {Any, AbstractArray{Int}} ()
237- if dvs != = nothing
238- for (j, x) in enumerate (dvs)
239- if iscall (x) && operation (x) == getindex
240- arg = arguments (x)[1 ]
241- inds = get! (() -> Int[], array_vars, arg)
242- push! (inds, j)
243- end
244- end
245- for (k, inds) in array_vars
246- if inds == (inds′ = inds[1 ]: inds[end ])
247- array_vars[k] = inds′
248- end
249- end
237+ var_to_arridxs = Dict ()
250238
251- uind = 1
252- else
239+ if dvs === nothing
253240 uind = 0
254- end
255- # values are (indexes, index of buffer, size of parameter)
256- array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
257- # If for some reason different elements of an array parameter are in different buffers
258- other_array_parameters = Dict {Any, Any} ()
259-
260- hasinputs = inputs != = nothing
261- input_vars = Dict {Any, AbstractArray{Int}} ()
262- if hasinputs
263- for (j, x) in enumerate (inputs)
264- if iscall (x) && operation (x) == getindex
265- arg = arguments (x)[1 ]
266- inds = get! (() -> Int[], input_vars, arg)
267- push! (inds, j)
268- end
269- end
270- for (k, inds) in input_vars
271- if inds == (inds′ = inds[1 ]: inds[end ])
272- input_vars[k] = inds′
273- end
274- end
275- end
276- if has_index_cache (sys)
277- ic = get_index_cache (sys)
278241 else
279- ic = nothing
280- end
281- if ps isa Tuple && eltype (ps) <: AbstractArray
282- ps = Iterators. flatten (ps)
283- end
284- for p in ps
285- p = unwrap (p)
286- if iscall (p) && operation (p) == getindex
287- p = arguments (p)[1 ]
288- end
289- symtype (p) <: AbstractArray && Symbolics. shape (p) != Symbolics. Unknown () || continue
290- scal = collect (p)
291- # all scalarized variables are in `ps`
292- any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
293- (haskey (array_parameters, p) || haskey (other_array_parameters, p)) && continue
294-
295- idx = parameter_index (sys, p)
296- idx isa Int && continue
297- if idx isa ParameterIndex
298- if idx. portion != SciMLStructures. Tunable ()
299- continue
300- end
301- array_parameters[p] = (vec (idx. idx), 1 , size (idx. idx))
242+ uind = 1
243+ for (i, x) in enumerate (dvs)
244+ iscall (x) && operation (x) == getindex || continue
245+ arg = arguments (x)[1 ]
246+ inds = get! (() -> [], var_to_arridxs, arg)
247+ push! (inds, (uind, i))
248+ end
249+ end
250+ p_start = uind + 1 + history
251+ rps = (reorder_parameters (sys, ps)... , cachesyms... )
252+ if inputs != = nothing
253+ rps = (inputs, rps... )
254+ end
255+ for sym in reduce (vcat, rps; init = [])
256+ iscall (sym) && operation (sym) == getindex || continue
257+ arg = arguments (sym)[1 ]
258+
259+ bufferidx = findfirst (buf -> any (isequal (sym), buf), rps)
260+ idxinbuffer = findfirst (isequal (sym), rps[bufferidx])
261+ inds = get! (() -> [], var_to_arridxs, arg)
262+ push! (inds, (p_start + bufferidx - 1 , idxinbuffer))
263+ end
264+
265+ viewsyms = Dict ()
266+ splitsyms = Dict ()
267+ for (arrsym, idxs) in var_to_arridxs
268+ length (idxs) == length (arrsym) || continue
269+ # allequal(first, idxs) is a 1.11 feature
270+ if allequal (Iterators. map (first, idxs))
271+ viewsyms[arrsym] = (first (first (idxs)), reshape (last .(idxs), size (arrsym)))
302272 else
303- # idx === nothing
304- idxs = map (Base. Fix1 (parameter_index, sys), scal)
305- if first (idxs) isa ParameterIndex
306- buffer_idxs = map (Base. Fix1 (iterated_buffer_index, ic), idxs)
307- if allequal (buffer_idxs)
308- buffer_idx = first (buffer_idxs)
309- if first (idxs). portion == SciMLStructures. Tunable ()
310- idxs = map (x -> x. idx, idxs)
311- else
312- idxs = map (x -> x. idx[end ], idxs)
313- end
314- else
315- other_array_parameters[p] = scal
316- continue
317- end
318- else
319- buffer_idx = 1
320- end
321-
322- sz = size (idxs)
323- if vec (idxs) == idxs[begin ]: idxs[end ]
324- idxs = idxs[begin ]: idxs[end ]
325- elseif vec (idxs) == idxs[begin ]: - 1 : idxs[end ]
326- idxs = idxs[begin ]: - 1 : idxs[end ]
327- end
328- idxs = vec (idxs)
329- array_parameters[p] = (idxs, buffer_idx, sz)
273+ splitsyms[arrsym] = reshape (idxs, size (arrsym))
330274 end
331275 end
332-
333- inputind = if history
334- uind + 2
335- else
336- uind + 1
337- end
338- params_offset = if history && hasinputs
339- uind + 2
340- elseif history || hasinputs
341- uind + 1
342- else
343- uind
344- end
345276 if isscalar
346277 function (expr)
347278 Func (
348279 expr. args,
349280 [],
350281 Let (
351282 vcat (
352- [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
353- [k ← :(view ($ (expr. args[inputind]. name), $ v))
354- for (k, v) in input_vars],
355- [k ← :(reshape (
356- view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
357- $ sz))
358- for (k, (idxs, buffer_idx, sz)) in array_parameters],
359- [k ← Code. MakeArray (v, symtype (k))
360- for (k, v) in other_array_parameters]
283+ [sym ← :(view ($ (expr. args[i]. name), $ idxs))
284+ for (sym, (i, idxs)) in viewsyms],
285+ [sym ←
286+ MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
287+ expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
361288 ),
362289 expr. body,
363290 false
@@ -371,15 +298,11 @@ function wrap_array_vars(
371298 [],
372299 Let (
373300 vcat (
374- [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
375- [k ← :(view ($ (expr. args[inputind]. name), $ v))
376- for (k, v) in input_vars],
377- [k ← :(reshape (
378- view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
379- $ sz))
380- for (k, (idxs, buffer_idx, sz)) in array_parameters],
381- [k ← Code. MakeArray (v, symtype (k))
382- for (k, v) in other_array_parameters]
301+ [sym ← :(view ($ (expr. args[i]. name), $ idxs))
302+ for (sym, (i, idxs)) in viewsyms],
303+ [sym ←
304+ MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
305+ expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
383306 ),
384307 expr. body,
385308 false
@@ -392,17 +315,11 @@ function wrap_array_vars(
392315 [],
393316 Let (
394317 vcat (
395- [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
396- for (k, v) in array_vars],
397- [k ← :(view ($ (expr. args[inputind + 1 ]. name), $ v))
398- for (k, v) in input_vars],
399- [k ← :(reshape (
400- view ($ (expr. args[params_offset + buffer_idx + 1 ]. name),
401- $ idxs),
402- $ sz))
403- for (k, (idxs, buffer_idx, sz)) in array_parameters],
404- [k ← Code. MakeArray (v, symtype (k))
405- for (k, v) in other_array_parameters]
318+ [sym ← :(view ($ (expr. args[i + 1 ]. name), $ idxs))
319+ for (sym, (i, idxs)) in viewsyms],
320+ [sym ← MakeArray (
321+ [expr. args[bufi + 1 ]. elems[vali] for (bufi, vali) in idxs],
322+ expr. args[idxs[1 ][1 ] + 1 ]) for (sym, idxs) in splitsyms]
406323 ),
407324 expr. body,
408325 false
0 commit comments