@@ -224,7 +224,7 @@ function wrap_assignments(isscalar, assignments; let_block = false)
224224end
225225
226226function wrap_array_vars (
227- sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys))
227+ sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys), inputs = nothing )
228228 isscalar = ! (exprs isa AbstractArray)
229229 array_vars = Dict {Any, AbstractArray{Int}} ()
230230 if dvs != = nothing
@@ -235,16 +235,42 @@ function wrap_array_vars(
235235 push! (inds, j)
236236 end
237237 end
238+ for (k, inds) in array_vars
239+ if inds == (inds′ = inds[1 ]: inds[end ])
240+ array_vars[k] = inds′
241+ end
242+ end
243+
238244 uind = 1
239245 else
240246 uind = 0
241247 end
242- # tunables are scalarized and concatenated, so we need to have assignments
243- # for the non-scalarized versions
244- array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
245- # Other parameters may be scalarized arrays but used in the vector form
248+ # values are (indexes, index of buffer, size of parameter)
249+ array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
250+ # If for some reason different elements of an array parameter are in different buffers
246251 other_array_parameters = Dict {Any, Any} ()
247252
253+ hasinputs = inputs != = nothing
254+ input_vars = Dict {Any, AbstractArray{Int}} ()
255+ if hasinputs
256+ for (j, x) in enumerate (inputs)
257+ if iscall (x) && operation (x) == getindex
258+ arg = arguments (x)[1 ]
259+ inds = get! (() -> Int[], input_vars, arg)
260+ push! (inds, j)
261+ end
262+ end
263+ for (k, inds) in input_vars
264+ if inds == (inds′ = inds[1 ]: inds[end ])
265+ input_vars[k] = inds′
266+ end
267+ end
268+ end
269+ if has_index_cache (sys)
270+ ic = get_index_cache (sys)
271+ else
272+ ic = nothing
273+ end
248274 if ps isa Tuple && eltype (ps) <: AbstractArray
249275 ps = Iterators. flatten (ps)
250276 end
@@ -257,25 +283,33 @@ function wrap_array_vars(
257283 scal = collect (p)
258284 # all scalarized variables are in `ps`
259285 any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
260- (haskey (array_tunables , p) || haskey (other_array_parameters, p)) && continue
286+ (haskey (array_parameters , p) || haskey (other_array_parameters, p)) && continue
261287
262288 idx = parameter_index (sys, p)
263289 idx isa Int && continue
264290 if idx isa ParameterIndex
265291 if idx. portion != SciMLStructures. Tunable ()
266292 continue
267293 end
268- idxs = vec (idx. idx)
269- sz = size (idx. idx)
294+ array_parameters[p] = (vec (idx. idx), 1 , size (idx. idx))
270295 else
271296 # idx === nothing
272297 idxs = map (Base. Fix1 (parameter_index, sys), scal)
273- if all (x -> x isa ParameterIndex && x. portion isa SciMLStructures. Tunable, idxs)
274- idxs = map (x -> x. idx, idxs)
275- end
276- if ! all (x -> x isa Int, idxs)
277- other_array_parameters[p] = scal
278- continue
298+ if first (idxs) isa ParameterIndex
299+ buffer_idxs = map (Base. Fix1 (iterated_buffer_index, ic), idxs)
300+ if allequal (buffer_idxs)
301+ buffer_idx = first (buffer_idxs)
302+ if first (idxs). portion == SciMLStructures. Tunable ()
303+ idxs = map (x -> x. idx, idxs)
304+ else
305+ idxs = map (x -> x. idx[end ], idxs)
306+ end
307+ else
308+ other_array_parameters[p] = scal
309+ continue
310+ end
311+ else
312+ buffer_idx = 1
279313 end
280314
281315 sz = size (idxs)
@@ -285,12 +319,7 @@ function wrap_array_vars(
285319 idxs = idxs[begin ]: - 1 : idxs[end ]
286320 end
287321 idxs = vec (idxs)
288- end
289- array_tunables[p] = (idxs, sz)
290- end
291- for (k, inds) in array_vars
292- if inds == (inds′ = inds[1 ]: inds[end ])
293- array_vars[k] = inds′
322+ array_parameters[p] = (idxs, buffer_idx, sz)
294323 end
295324 end
296325 if isscalar
@@ -301,8 +330,12 @@ function wrap_array_vars(
301330 Let (
302331 vcat (
303332 [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
304- [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
305- for (k, (idxs, sz)) in array_tunables],
333+ [k ← :(view ($ (expr. args[uind + hasinputs]. name), $ v))
334+ for (k, v) in input_vars],
335+ [k ← :(reshape (
336+ view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
337+ $ sz))
338+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
306339 [k ← Code. MakeArray (v, symtype (k))
307340 for (k, v) in other_array_parameters]
308341 ),
@@ -319,8 +352,12 @@ function wrap_array_vars(
319352 Let (
320353 vcat (
321354 [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
322- [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
323- for (k, (idxs, sz)) in array_tunables],
355+ [k ← :(view ($ (expr. args[uind + hasinputs]. name), $ v))
356+ for (k, v) in input_vars],
357+ [k ← :(reshape (
358+ view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
359+ $ sz))
360+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
324361 [k ← Code. MakeArray (v, symtype (k))
325362 for (k, v) in other_array_parameters]
326363 ),
@@ -337,8 +374,13 @@ function wrap_array_vars(
337374 vcat (
338375 [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
339376 for (k, v) in array_vars],
340- [k ← :(reshape (view ($ (expr. args[uind + 2 ]. name), $ idxs), $ sz))
341- for (k, (idxs, sz)) in array_tunables],
377+ [k ← :(view ($ (expr. args[uind + hasinputs + 1 ]. name), $ v))
378+ for (k, v) in input_vars],
379+ [k ← :(reshape (
380+ view ($ (expr. args[uind + hasinputs + buffer_idx + 1 ]. name),
381+ $ idxs),
382+ $ sz))
383+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
342384 [k ← Code. MakeArray (v, symtype (k))
343385 for (k, v) in other_array_parameters]
344386 ),
0 commit comments