Skip to content

Commit cb6ca4c

Browse files
feat: add support for extra_args in wrap_array_vars
1 parent a3789ae commit cb6ca4c

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

src/systems/abstractsystem.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,33 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
231231
end
232232

233+
"""
234+
$(TYPEDSIGNATURES)
235+
236+
Add the necessary assignment statements to allow use of unscalarized array variables
237+
in the generated code. `expr` is the expression returned by the function. `dvs` and
238+
`ps` are the unknowns and parameters of the system `sys` to use in the generated code.
239+
`inputs` can be specified as an array of symbolics if the generated function has inputs.
240+
If `history == true`, the generated function accepts a history function. `cachesyms` are
241+
extra variables (arrays of variables) stored in the cache array(s) of the parameter
242+
object. `extra_args` are extra arguments appended to the end of the argument list.
243+
244+
The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)`
245+
Where:
246+
- `du` is the optional buffer to write to for in-place functions.
247+
- `u` is the list of unknowns. This argument is not present if `dvs === nothing`.
248+
- `h` is the optional history function, present if `history == true`.
249+
- `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed
250+
to be in the order of variables passed to `inputs`.
251+
- `p` is the parameter object.
252+
- `cache_syms` are the cache variables. These are part of the splatted parameter object.
253+
- `t` is time, present only if the system is time dependent.
254+
- `extra_args` are the extra arguments passed to the function, present only if
255+
`extra_args` is non-empty.
256+
"""
233257
function wrap_array_vars(
234258
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
235-
inputs = nothing, history = false, cachesyms::Tuple = ())
259+
inputs = nothing, history = false, cachesyms::Tuple = (), extra_args::Tuple = ())
236260
isscalar = !(exprs isa AbstractArray)
237261
var_to_arridxs = Dict()
238262

@@ -252,6 +276,10 @@ function wrap_array_vars(
252276
if inputs !== nothing
253277
rps = (inputs, rps...)
254278
end
279+
if has_iv(sys)
280+
rps = (rps..., get_iv(sys))
281+
end
282+
rps = (rps..., extra_args...)
255283
for sym in reduce(vcat, rps; init = [])
256284
iscall(sym) && operation(sym) == getindex || continue
257285
arg = arguments(sym)[1]

0 commit comments

Comments
 (0)