Skip to content

Commit 65e59d0

Browse files
feat: update wrap_array_vars to handle history function
1 parent eb50d74 commit 65e59d0

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
end
231231

232232
function wrap_array_vars(
233-
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
233+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
234+
inputs = nothing, history = false)
234235
isscalar = !(exprs isa AbstractArray)
235236
array_vars = Dict{Any, AbstractArray{Int}}()
236237
if dvs !== nothing
@@ -328,6 +329,19 @@ function wrap_array_vars(
328329
array_parameters[p] = (idxs, buffer_idx, sz)
329330
end
330331
end
332+
333+
inputind = if history
334+
uind + 2
335+
else
336+
uind + 1
337+
end
338+
params_offset = if history && hasinputs
339+
uind + 2
340+
elseif history || hasinputs
341+
uind + 1
342+
else
343+
uind
344+
end
331345
if isscalar
332346
function (expr)
333347
Func(
@@ -336,10 +350,10 @@ function wrap_array_vars(
336350
Let(
337351
vcat(
338352
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
339-
[k :(view($(expr.args[uind + hasinputs].name), $v))
353+
[k :(view($(expr.args[inputind].name), $v))
340354
for (k, v) in input_vars],
341355
[k :(reshape(
342-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
356+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
343357
$sz))
344358
for (k, (idxs, buffer_idx, sz)) in array_parameters],
345359
[k Code.MakeArray(v, symtype(k))
@@ -358,10 +372,10 @@ function wrap_array_vars(
358372
Let(
359373
vcat(
360374
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
361-
[k :(view($(expr.args[uind + hasinputs].name), $v))
375+
[k :(view($(expr.args[inputind].name), $v))
362376
for (k, v) in input_vars],
363377
[k :(reshape(
364-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
378+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
365379
$sz))
366380
for (k, (idxs, buffer_idx, sz)) in array_parameters],
367381
[k Code.MakeArray(v, symtype(k))
@@ -380,10 +394,10 @@ function wrap_array_vars(
380394
vcat(
381395
[k :(view($(expr.args[uind + 1].name), $v))
382396
for (k, v) in array_vars],
383-
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
397+
[k :(view($(expr.args[inputind + 1].name), $v))
384398
for (k, v) in input_vars],
385399
[k :(reshape(
386-
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
400+
view($(expr.args[params_offset + buffer_idx + 1].name),
387401
$idxs),
388402
$sz))
389403
for (k, (idxs, buffer_idx, sz)) in array_parameters],

0 commit comments

Comments
 (0)