@@ -224,7 +224,7 @@ function wrap_assignments(isscalar, assignments; let_block = false)
224
224
end
225
225
226
226
function wrap_array_vars (
227
- sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys))
227
+ sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys), inputs = nothing )
228
228
isscalar = ! (exprs isa AbstractArray)
229
229
array_vars = Dict {Any, AbstractArray{Int}} ()
230
230
if dvs != = nothing
@@ -235,16 +235,42 @@ function wrap_array_vars(
235
235
push! (inds, j)
236
236
end
237
237
end
238
+ for (k, inds) in array_vars
239
+ if inds == (inds′ = inds[1 ]: inds[end ])
240
+ array_vars[k] = inds′
241
+ end
242
+ end
243
+
238
244
uind = 1
239
245
else
240
246
uind = 0
241
247
end
242
- # tunables are scalarized and concatenated, so we need to have assignments
243
- # for the non-scalarized versions
244
- array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
245
- # Other parameters may be scalarized arrays but used in the vector form
248
+ # values are (indexes, index of buffer, size of parameter)
249
+ array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
250
+ # If for some reason different elements of an array parameter are in different buffers
246
251
other_array_parameters = Dict {Any, Any} ()
247
252
253
+ hasinputs = inputs != = nothing
254
+ input_vars = Dict {Any, AbstractArray{Int}} ()
255
+ if hasinputs
256
+ for (j, x) in enumerate (inputs)
257
+ if iscall (x) && operation (x) == getindex
258
+ arg = arguments (x)[1 ]
259
+ inds = get! (() -> Int[], input_vars, arg)
260
+ push! (inds, j)
261
+ end
262
+ end
263
+ for (k, inds) in input_vars
264
+ if inds == (inds′ = inds[1 ]: inds[end ])
265
+ input_vars[k] = inds′
266
+ end
267
+ end
268
+ end
269
+ if has_index_cache (sys)
270
+ ic = get_index_cache (sys)
271
+ else
272
+ ic = nothing
273
+ end
248
274
if ps isa Tuple && eltype (ps) <: AbstractArray
249
275
ps = Iterators. flatten (ps)
250
276
end
@@ -257,25 +283,33 @@ function wrap_array_vars(
257
283
scal = collect (p)
258
284
# all scalarized variables are in `ps`
259
285
any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
260
- (haskey (array_tunables , p) || haskey (other_array_parameters, p)) && continue
286
+ (haskey (array_parameters , p) || haskey (other_array_parameters, p)) && continue
261
287
262
288
idx = parameter_index (sys, p)
263
289
idx isa Int && continue
264
290
if idx isa ParameterIndex
265
291
if idx. portion != SciMLStructures. Tunable ()
266
292
continue
267
293
end
268
- idxs = vec (idx. idx)
269
- sz = size (idx. idx)
294
+ array_parameters[p] = (vec (idx. idx), 1 , size (idx. idx))
270
295
else
271
296
# idx === nothing
272
297
idxs = map (Base. Fix1 (parameter_index, sys), scal)
273
- if all (x -> x isa ParameterIndex && x. portion isa SciMLStructures. Tunable, idxs)
274
- idxs = map (x -> x. idx, idxs)
275
- end
276
- if ! all (x -> x isa Int, idxs)
277
- other_array_parameters[p] = scal
278
- continue
298
+ if first (idxs) isa ParameterIndex
299
+ buffer_idxs = map (Base. Fix1 (iterated_buffer_index, ic), idxs)
300
+ if allequal (buffer_idxs)
301
+ buffer_idx = first (buffer_idxs)
302
+ if first (idxs). portion == SciMLStructures. Tunable ()
303
+ idxs = map (x -> x. idx, idxs)
304
+ else
305
+ idxs = map (x -> x. idx[end ], idxs)
306
+ end
307
+ else
308
+ other_array_parameters[p] = scal
309
+ continue
310
+ end
311
+ else
312
+ buffer_idx = 1
279
313
end
280
314
281
315
sz = size (idxs)
@@ -285,12 +319,7 @@ function wrap_array_vars(
285
319
idxs = idxs[begin ]: - 1 : idxs[end ]
286
320
end
287
321
idxs = vec (idxs)
288
- end
289
- array_tunables[p] = (idxs, sz)
290
- end
291
- for (k, inds) in array_vars
292
- if inds == (inds′ = inds[1 ]: inds[end ])
293
- array_vars[k] = inds′
322
+ array_parameters[p] = (idxs, buffer_idx, sz)
294
323
end
295
324
end
296
325
if isscalar
@@ -301,8 +330,12 @@ function wrap_array_vars(
301
330
Let (
302
331
vcat (
303
332
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
304
- [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
305
- for (k, (idxs, sz)) in array_tunables],
333
+ [k ← :(view ($ (expr. args[uind + hasinputs]. name), $ v))
334
+ for (k, v) in input_vars],
335
+ [k ← :(reshape (
336
+ view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
337
+ $ sz))
338
+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
306
339
[k ← Code. MakeArray (v, symtype (k))
307
340
for (k, v) in other_array_parameters]
308
341
),
@@ -319,8 +352,12 @@ function wrap_array_vars(
319
352
Let (
320
353
vcat (
321
354
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
322
- [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
323
- for (k, (idxs, sz)) in array_tunables],
355
+ [k ← :(view ($ (expr. args[uind + hasinputs]. name), $ v))
356
+ for (k, v) in input_vars],
357
+ [k ← :(reshape (
358
+ view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
359
+ $ sz))
360
+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
324
361
[k ← Code. MakeArray (v, symtype (k))
325
362
for (k, v) in other_array_parameters]
326
363
),
@@ -337,8 +374,13 @@ function wrap_array_vars(
337
374
vcat (
338
375
[k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
339
376
for (k, v) in array_vars],
340
- [k ← :(reshape (view ($ (expr. args[uind + 2 ]. name), $ idxs), $ sz))
341
- for (k, (idxs, sz)) in array_tunables],
377
+ [k ← :(view ($ (expr. args[uind + hasinputs + 1 ]. name), $ v))
378
+ for (k, v) in input_vars],
379
+ [k ← :(reshape (
380
+ view ($ (expr. args[uind + hasinputs + buffer_idx + 1 ]. name),
381
+ $ idxs),
382
+ $ sz))
383
+ for (k, (idxs, buffer_idx, sz)) in array_parameters],
342
384
[k ← Code. MakeArray (v, symtype (k))
343
385
for (k, v) in other_array_parameters]
344
386
),
0 commit comments