Skip to content

Commit ba343fa

Browse files
Merge pull request #2940 from AayushSabharwal/as/wrap-array-inputs
fix: fix vectorization of array variables with inputs
2 parents 08287a4 + fe809dc commit ba343fa

File tree

3 files changed

+100
-28
lines changed

3 files changed

+100
-28
lines changed

src/systems/abstractsystem.jl

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ function wrap_assignments(isscalar, assignments; let_block = false)
224224
end
225225

226226
function wrap_array_vars(
227-
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
227+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
228228
isscalar = !(exprs isa AbstractArray)
229229
array_vars = Dict{Any, AbstractArray{Int}}()
230230
if dvs !== nothing
@@ -235,16 +235,42 @@ function wrap_array_vars(
235235
push!(inds, j)
236236
end
237237
end
238+
for (k, inds) in array_vars
239+
if inds == (inds′ = inds[1]:inds[end])
240+
array_vars[k] = inds′
241+
end
242+
end
243+
238244
uind = 1
239245
else
240246
uind = 0
241247
end
242-
# tunables are scalarized and concatenated, so we need to have assignments
243-
# for the non-scalarized versions
244-
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
245-
# Other parameters may be scalarized arrays but used in the vector form
248+
# values are (indexes, index of buffer, size of parameter)
249+
array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
250+
# If for some reason different elements of an array parameter are in different buffers
246251
other_array_parameters = Dict{Any, Any}()
247252

253+
hasinputs = inputs !== nothing
254+
input_vars = Dict{Any, AbstractArray{Int}}()
255+
if hasinputs
256+
for (j, x) in enumerate(inputs)
257+
if iscall(x) && operation(x) == getindex
258+
arg = arguments(x)[1]
259+
inds = get!(() -> Int[], input_vars, arg)
260+
push!(inds, j)
261+
end
262+
end
263+
for (k, inds) in input_vars
264+
if inds == (inds′ = inds[1]:inds[end])
265+
input_vars[k] = inds′
266+
end
267+
end
268+
end
269+
if has_index_cache(sys)
270+
ic = get_index_cache(sys)
271+
else
272+
ic = nothing
273+
end
248274
if ps isa Tuple && eltype(ps) <: AbstractArray
249275
ps = Iterators.flatten(ps)
250276
end
@@ -257,25 +283,33 @@ function wrap_array_vars(
257283
scal = collect(p)
258284
# all scalarized variables are in `ps`
259285
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
260-
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue
286+
(haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue
261287

262288
idx = parameter_index(sys, p)
263289
idx isa Int && continue
264290
if idx isa ParameterIndex
265291
if idx.portion != SciMLStructures.Tunable()
266292
continue
267293
end
268-
idxs = vec(idx.idx)
269-
sz = size(idx.idx)
294+
array_parameters[p] = (vec(idx.idx), 1, size(idx.idx))
270295
else
271296
# idx === nothing
272297
idxs = map(Base.Fix1(parameter_index, sys), scal)
273-
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
274-
idxs = map(x -> x.idx, idxs)
275-
end
276-
if !all(x -> x isa Int, idxs)
277-
other_array_parameters[p] = scal
278-
continue
298+
if first(idxs) isa ParameterIndex
299+
buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs)
300+
if allequal(buffer_idxs)
301+
buffer_idx = first(buffer_idxs)
302+
if first(idxs).portion == SciMLStructures.Tunable()
303+
idxs = map(x -> x.idx, idxs)
304+
else
305+
idxs = map(x -> x.idx[end], idxs)
306+
end
307+
else
308+
other_array_parameters[p] = scal
309+
continue
310+
end
311+
else
312+
buffer_idx = 1
279313
end
280314

281315
sz = size(idxs)
@@ -285,12 +319,7 @@ function wrap_array_vars(
285319
idxs = idxs[begin]:-1:idxs[end]
286320
end
287321
idxs = vec(idxs)
288-
end
289-
array_tunables[p] = (idxs, sz)
290-
end
291-
for (k, inds) in array_vars
292-
if inds == (inds′ = inds[1]:inds[end])
293-
array_vars[k] = inds′
322+
array_parameters[p] = (idxs, buffer_idx, sz)
294323
end
295324
end
296325
if isscalar
@@ -301,8 +330,12 @@ function wrap_array_vars(
301330
Let(
302331
vcat(
303332
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
304-
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
305-
for (k, (idxs, sz)) in array_tunables],
333+
[k :(view($(expr.args[uind + hasinputs].name), $v))
334+
for (k, v) in input_vars],
335+
[k :(reshape(
336+
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
337+
$sz))
338+
for (k, (idxs, buffer_idx, sz)) in array_parameters],
306339
[k Code.MakeArray(v, symtype(k))
307340
for (k, v) in other_array_parameters]
308341
),
@@ -319,8 +352,12 @@ function wrap_array_vars(
319352
Let(
320353
vcat(
321354
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
322-
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
323-
for (k, (idxs, sz)) in array_tunables],
355+
[k :(view($(expr.args[uind + hasinputs].name), $v))
356+
for (k, v) in input_vars],
357+
[k :(reshape(
358+
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
359+
$sz))
360+
for (k, (idxs, buffer_idx, sz)) in array_parameters],
324361
[k Code.MakeArray(v, symtype(k))
325362
for (k, v) in other_array_parameters]
326363
),
@@ -337,8 +374,13 @@ function wrap_array_vars(
337374
vcat(
338375
[k :(view($(expr.args[uind + 1].name), $v))
339376
for (k, v) in array_vars],
340-
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
341-
for (k, (idxs, sz)) in array_tunables],
377+
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
378+
for (k, v) in input_vars],
379+
[k :(reshape(
380+
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
381+
$idxs),
382+
$sz))
383+
for (k, (idxs, buffer_idx, sz)) in array_parameters],
342384
[k Code.MakeArray(v, symtype(k))
343385
for (k, v) in other_array_parameters]
344386
),

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,9 +498,9 @@ function build_explicit_observed_function(sys, ts;
498498
pre = get_postprocess_fbody(sys)
499499

500500
array_wrapper = if param_only
501-
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
501+
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing, inputs)
502502
else
503-
wrap_array_vars(sys, ts; ps = _ps)
503+
wrap_array_vars(sys, ts; ps = _ps, inputs)
504504
end
505505
# Need to keep old method of building the function since it uses `output_type`,
506506
# which can't be provided to `build_function`

src/systems/index_cache.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,33 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
501501
end
502502
return result
503503
end
504+
505+
# Given a parameter index, find the index of the buffer it is in when
506+
# `MTKParameters` is iterated
507+
function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex)
508+
idx = 0
509+
if ind.portion isa SciMLStructures.Tunable
510+
return idx + 1
511+
elseif ic.tunable_buffer_size.length > 0
512+
idx += 1
513+
end
514+
if ind.portion isa SciMLStructures.Discrete
515+
return idx + length(first(ic.discrete_buffer_sizes)) * (ind.idx[1] - 1) + ind.idx[2]
516+
elseif !isempty(ic.discrete_buffer_sizes)
517+
idx += length(first(ic.discrete_buffer_sizes)) * length(ic.discrete_buffer_sizes)
518+
end
519+
if ind.portion isa SciMLStructures.Constants
520+
return return idx + ind.idx[1]
521+
elseif !isempty(ic.constant_buffer_sizes)
522+
idx += length(ic.constant_buffer_sizes)
523+
end
524+
if ind.portion == DEPENDENT_PORTION
525+
return idx + ind.idx[1]
526+
elseif !isempty(ic.dependent_buffer_sizes)
527+
idx += length(ic.dependent_buffer_sizes)
528+
end
529+
if ind.portion == NONNUMERIC_PORTION
530+
return idx + ind.idx[1]
531+
end
532+
error("Unhandled portion $(ind.portion)")
533+
end

0 commit comments

Comments
 (0)