Skip to content

Commit 80136d9

Browse files
docs: add docstring for code generation utils
1 parent a8218eb commit 80136d9

File tree

1 file changed

+82
-3
lines changed

1 file changed

+82
-3
lines changed

src/systems/codegen_utils.jl

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,134 @@
1+
"""
2+
$(TYPEDSIGNATURES)
3+
4+
Return the name for the `i`th argument in a function generated by `build_function_wrapper`.
5+
"""
16
function generated_argument_name(i::Int)
27
return Symbol(:__mtk_arg_, i)
38
end
49

10+
"""
11+
$(TYPEDSIGNATURES)
12+
13+
Given the arguments to `build_function_wrapper`, return a list of assignments which
14+
reconstruct array variables if they are present scalarized in `args`.
15+
"""
516
function array_variable_assignments(args...)
17+
# map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
618
var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
719
for (i, arg) in enumerate(args)
20+
# filter out non-arrays
21+
# any element of args which is not an array is assumed to not contain a
22+
# scalarized array symbolic. This works because the only non-array element
23+
# is the independent variable
824
symbolic_type(arg) == NotSymbolic() || continue
925
arg isa AbstractArray || continue
1026

27+
# go through symbolics
1128
for (j, var) in enumerate(arg)
1229
var = unwrap(var)
30+
# filter out non-array-symbolics
1331
iscall(var) || continue
1432
operation(var) == getindex || continue
1533
arrvar = arguments(var)[1]
34+
# get and/or construct the buffer storing indexes
1635
idxbuffer = get!(() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
1736
idxbuffer[arguments(var)[2:end]...] = (i, j)
1837
end
1938
end
2039

2140
assignments = Assignment[]
2241
for (arrvar, idxs) in var_to_arridxs
42+
# all elements of the array need to be present in `args` to form the
43+
# reconstructing assignment
2344
any(iszero first, idxs) && continue
2445

46+
# if they are all in the same buffer, we can take a shortcut and `view` into it
2547
if allequal(Iterators.map(first, idxs))
2648
buffer_idx = first(first(idxs))
2749
idxs = map(last, idxs)
50+
# if all the elements are contiguous and ordered, turn the array of indexes into a range
51+
# to help reduce allocations
2852
if first(idxs) < last(idxs) && vec(idxs) == first(idxs):last(idxs)
2953
idxs = first(idxs):last(idxs)
3054
elseif vec(idxs) == last(idxs):-1:first(idxs)
3155
idxs = last(idxs):-1:first(idxs)
3256
else
57+
# Otherwise, turn the indexes into an `SArray` so they're stack-allocated
3358
idxs = SArray{Tuple{size(idxs)...}}(idxs)
3459
end
60+
# view and reshape
3561
push!(assignments, arrvar term(reshape, term(view, generated_argument_name(buffer_idx), idxs), size(arrvar)))
3662
else
3763
elems = map(idxs) do idx
3864
i, j = idx
3965
term(getindex, generated_argument_name(i), j)
4066
end
67+
# use `MakeArray` and generate a stack-allocated array
4168
push!(assignments, arrvar MakeArray(elems, SArray))
4269
end
4370
end
4471

4572
return assignments
4673
end
4774

75+
"""
76+
$(TYPEDSIGNATURES)
77+
78+
A wrapper around `build_function` which performs the necessary transformations for
79+
code generation of all types of systems. `expr` is the expression returned from the
80+
generated functions, and `args` are the arguments.
81+
82+
# Keyword Arguments
83+
84+
- `p_start`, `p_end`: Denotes the indexes in `args` where the buffers of the splatted
85+
`MTKParameters` object are present. These are collapsed into a single argument and
86+
destructured inside the function. `p_start` must also be provided for non-split systems
87+
since it is used by `wrap_delays`.
88+
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
89+
calls to a history function. The history function is added to the list of arguments
90+
right before parameters, at the index `p_start`.
91+
- `wrap_code`: Forwarded to `build_function`.
92+
- `add_observed`: Whether to add assignment statements for observed equations in the
93+
generated code.
94+
- `filter_observed`: A predicate function to filter out observed equations which should
95+
not be added to the generated code.
96+
- `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
97+
`args` in the generated code. If `false`, all usages of the individual symbolics will
98+
instead call `getindex` on the relevant argument. This is useful if the generated
99+
function writes to one of its arguments and expects subsequent code to use the new
100+
values. Note that the collapsed `MTKParameters` argument will always be explicitly
101+
destructured regardless of this keyword argument.
102+
- `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
103+
this will be passed to the `similarto` argument of `build_function`. If `output_type`
104+
is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
105+
whether it is scalar or an array).
106+
- `mkarray`: A function which accepts `expr` and `output_type` and returns a code
107+
generation object similar to `MakeArray` or `MakeTuple` to be used to generate
108+
code for `expr`.
109+
- `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
110+
argument.
111+
112+
All other keyword arguments are forwarded to `build_function`.
113+
"""
48114
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), wrap_delays = is_dde(sys), wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = true, output_type = nothing, mkarray = nothing, wrap_mtkparameters = true, kwargs...)
49115
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
50-
116+
# filter observed equations
51117
obs = filter(filter_observed, observed(sys))
118+
# turn delayed unknowns into calls to the history function
52119
if wrap_delays
53120
history_arg = is_split(sys) ? MTKPARAMETERS_ARG : generated_argument_name(p_start)
54121
obs = map(obs) do eq
55122
delay_to_function(sys, eq; history_arg)
56123
end
57124
expr = delay_to_function(sys, expr; history_arg)
125+
# add extra argument
58126
args = (args[1:p_start-1]..., DDE_HISTORY_FUN, args[p_start:end]...)
59127
p_start += 1
60128
p_end += 1
61129
end
62130
pdeps = parameter_dependencies(sys)
63-
131+
# get the constants to add to the code
64132
cmap, _ = get_cmap(sys)
65133
extra_constants = collect_constants(expr)
66134
filter!(extra_constants) do c
@@ -69,13 +137,15 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
69137
for c in extra_constants
70138
push!(cmap, c ~ getdefault(c))
71139
end
140+
# only get the necessary observed equations, avoiding extra computation
72141
if add_observed
73142
obsidxs = observed_equations_used_by(sys, expr)
74143
else
75144
obsidxs = Int[]
76145
end
146+
# similarly for parameter dependency equations
77147
pdepidxs = observed_equations_used_by(sys, expr; obs = pdeps)
78-
148+
# assignments for reconstructing scalarized array symbolics
79149
assignments = array_variable_assignments(args...)
80150

81151
for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs]))
@@ -84,35 +154,43 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
84154

85155
args = ntuple(Val(length(args))) do i
86156
arg = args[i]
157+
# for time-dependent systems, all arguments are passed through `time_varying_as_func`
158+
# TODO: This is legacy behavior and a candidate for removal in v10 since we have callable
159+
# parameters now.
87160
if is_time_dependent(sys)
88161
arg = if symbolic_type(arg) == NotSymbolic()
89162
arg isa AbstractArray ? map(x -> time_varying_as_func(unwrap(x), sys), arg) : arg
90163
else
91164
time_varying_as_func(unwrap(arg), sys)
92165
end
93166
end
167+
# Make sure to use the proper names for arguments
94168
if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
95169
DestructuredArgs(arg, generated_argument_name(i); create_bindings)
96170
else
97171
arg
98172
end
99173
end
100174

175+
# wrap into a single MTKParameters argument
101176
if is_split(sys) && wrap_mtkparameters
102177
if p_start > p_end
178+
# In case there are no parameter buffers, still insert an argument
103179
args = (args[1:p_start-1]..., MTKPARAMETERS_ARG, args[p_end+1:end]...)
104180
else
105181
# cannot apply `create_bindings` here since it doesn't nest
106182
args = (args[1:p_start-1]..., DestructuredArgs(collect(args[p_start:p_end]), MTKPARAMETERS_ARG), args[p_end+1:end]...)
107183
end
108184
end
109185

186+
# add preface assignments
110187
if has_preface(sys) && (pref = preface(sys)) !== nothing
111188
append!(assignments, pref)
112189
end
113190

114191
wrap_code = wrap_code .∘ wrap_assignments(isscalar, assignments)
115192

193+
# handling of `output_type` and `mkarray`
116194
similarto = nothing
117195
if output_type === Tuple
118196
expr = MakeTuple(Tuple(expr))
@@ -124,6 +202,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
124202
wrap_code = wrap_code[2]
125203
end
126204

205+
# scalar `build_function` only accepts a single function for `wrap_code`.
127206
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
128207
wrap_code = wrap_code[1]
129208
end

0 commit comments

Comments
 (0)