Skip to content

Commit 293e5e5

Browse files
feat: update wrap_array_vars to handle history function
1 parent 2dbdd4c commit 293e5e5

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
end
231231

232232
function wrap_array_vars(
233-
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
233+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing, history = false)
234234
isscalar = !(exprs isa AbstractArray)
235235
array_vars = Dict{Any, AbstractArray{Int}}()
236236
if dvs !== nothing
@@ -328,6 +328,19 @@ function wrap_array_vars(
328328
array_parameters[p] = (idxs, buffer_idx, sz)
329329
end
330330
end
331+
332+
inputind = if history
333+
uind + 2
334+
else
335+
uind + 1
336+
end
337+
params_offset = if history && hasinputs
338+
uind + 2
339+
elseif history || hasinputs
340+
uind + 1
341+
else
342+
uind
343+
end
331344
if isscalar
332345
function (expr)
333346
Func(
@@ -336,10 +349,10 @@ function wrap_array_vars(
336349
Let(
337350
vcat(
338351
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
339-
[k :(view($(expr.args[uind + hasinputs].name), $v))
352+
[k :(view($(expr.args[inputind].name), $v))
340353
for (k, v) in input_vars],
341354
[k :(reshape(
342-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
355+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
343356
$sz))
344357
for (k, (idxs, buffer_idx, sz)) in array_parameters],
345358
[k Code.MakeArray(v, symtype(k))
@@ -358,10 +371,10 @@ function wrap_array_vars(
358371
Let(
359372
vcat(
360373
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
361-
[k :(view($(expr.args[uind + hasinputs].name), $v))
374+
[k :(view($(expr.args[inputind].name), $v))
362375
for (k, v) in input_vars],
363376
[k :(reshape(
364-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
377+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
365378
$sz))
366379
for (k, (idxs, buffer_idx, sz)) in array_parameters],
367380
[k Code.MakeArray(v, symtype(k))
@@ -380,10 +393,10 @@ function wrap_array_vars(
380393
vcat(
381394
[k :(view($(expr.args[uind + 1].name), $v))
382395
for (k, v) in array_vars],
383-
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
396+
[k :(view($(expr.args[inputind + 1].name), $v))
384397
for (k, v) in input_vars],
385398
[k :(reshape(
386-
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
399+
view($(expr.args[params_offset + buffer_idx + 1].name),
387400
$idxs),
388401
$sz))
389402
for (k, (idxs, buffer_idx, sz)) in array_parameters],

0 commit comments

Comments
 (0)