Skip to content

Commit dfc49b4

Browse files
feat: add extra_assignments to build_function_wrapper
1 parent 4274ec9 commit dfc49b4

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/systems/codegen_utils.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ function array_variable_assignments(args...)
5959
idxs = SArray{Tuple{size(idxs)...}}(idxs)
6060
end
6161
# view and reshape
62-
push!(assignments, arrvar term(reshape, term(view, generated_argument_name(buffer_idx), idxs), size(arrvar)))
62+
push!(assignments,
63+
arrvar
64+
term(reshape, term(view, generated_argument_name(buffer_idx), idxs),
65+
size(arrvar)))
6366
else
6467
elems = map(idxs) do idx
6568
i, j = idx
@@ -109,10 +112,17 @@ generated functions, and `args` are the arguments.
109112
code for `expr`.
110113
- `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
111114
argument.
115+
- `extra_assignments`: Extra `Assignment` statements to prefix to `expr`, after all other
116+
assignments.
112117
113118
All other keyword arguments are forwarded to `build_function`.
114119
"""
115-
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), wrap_delays = is_dde(sys), wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = true, output_type = nothing, mkarray = nothing, wrap_mtkparameters = true, kwargs...)
120+
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
121+
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args),
122+
wrap_delays = is_dde(sys), wrap_code = identity,
123+
add_observed = true, filter_observed = Returns(true),
124+
create_bindings = true, output_type = nothing, mkarray = nothing,
125+
wrap_mtkparameters = true, extra_assignments = Assignment[], kwargs...)
116126
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
117127
# filter observed equations
118128
obs = filter(filter_observed, observed(sys))
@@ -152,6 +162,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
152162
for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs]))
153163
push!(assignments, eq.lhs eq.rhs)
154164
end
165+
append!(assignments, extra_assignments)
155166

156167
args = ntuple(Val(length(args))) do i
157168
arg = args[i]

0 commit comments

Comments
 (0)