Skip to content

Commit 4274ec9

Browse files
docs: add docstring for code generation utils
1 parent cff02c9 commit 4274ec9

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

src/systems/codegen_utils.jl

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,135 @@
1+
"""
2+
$(TYPEDSIGNATURES)
3+
4+
Return the name for the `i`th argument in a function generated by `build_function_wrapper`.
5+
"""
16
function generated_argument_name(i::Int)
27
return Symbol(:__mtk_arg_, i)
38
end
49

10+
"""
11+
$(TYPEDSIGNATURES)
12+
13+
Given the arguments to `build_function_wrapper`, return a list of assignments which
14+
reconstruct array variables if they are present scalarized in `args`.
15+
"""
516
function array_variable_assignments(args...)
17+
# map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
618
var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
719
for (i, arg) in enumerate(args)
20+
# filter out non-arrays
21+
# any element of args which is not an array is assumed to not contain a
22+
# scalarized array symbolic. This works because the only non-array element
23+
# is the independent variable
824
symbolic_type(arg) == NotSymbolic() || continue
925
arg isa AbstractArray || continue
1026

27+
# go through symbolics
1128
for (j, var) in enumerate(arg)
1229
var = unwrap(var)
30+
# filter out non-array-symbolics
1331
iscall(var) || continue
1432
operation(var) == getindex || continue
1533
arrvar = arguments(var)[1]
16-
idxbuffer = get!(() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
34+
# get and/or construct the buffer storing indexes
35+
idxbuffer = get!(
36+
() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
1737
idxbuffer[arguments(var)[2:end]...] = (i, j)
1838
end
1939
end
2040

2141
assignments = Assignment[]
2242
for (arrvar, idxs) in var_to_arridxs
43+
# all elements of the array need to be present in `args` to form the
44+
# reconstructing assignment
2345
any(iszero first, idxs) && continue
2446

47+
# if they are all in the same buffer, we can take a shortcut and `view` into it
2548
if allequal(Iterators.map(first, idxs))
2649
buffer_idx = first(first(idxs))
2750
idxs = map(last, idxs)
51+
# if all the elements are contiguous and ordered, turn the array of indexes into a range
52+
# to help reduce allocations
2853
if first(idxs) < last(idxs) && vec(idxs) == first(idxs):last(idxs)
2954
idxs = first(idxs):last(idxs)
3055
elseif vec(idxs) == last(idxs):-1:first(idxs)
3156
idxs = last(idxs):-1:first(idxs)
3257
else
58+
# Otherwise, turn the indexes into an `SArray` so they're stack-allocated
3359
idxs = SArray{Tuple{size(idxs)...}}(idxs)
3460
end
61+
# view and reshape
3562
push!(assignments, arrvar term(reshape, term(view, generated_argument_name(buffer_idx), idxs), size(arrvar)))
3663
else
3764
elems = map(idxs) do idx
3865
i, j = idx
3966
term(getindex, generated_argument_name(i), j)
4067
end
68+
# use `MakeArray` and generate a stack-allocated array
4169
push!(assignments, arrvar MakeArray(elems, SArray))
4270
end
4371
end
4472

4573
return assignments
4674
end
4775

76+
"""
77+
$(TYPEDSIGNATURES)
78+
79+
A wrapper around `build_function` which performs the necessary transformations for
80+
code generation of all types of systems. `expr` is the expression returned from the
81+
generated functions, and `args` are the arguments.
82+
83+
# Keyword Arguments
84+
85+
- `p_start`, `p_end`: Denotes the indexes in `args` where the buffers of the splatted
86+
`MTKParameters` object are present. These are collapsed into a single argument and
87+
destructured inside the function. `p_start` must also be provided for non-split systems
88+
since it is used by `wrap_delays`.
89+
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
90+
calls to a history function. The history function is added to the list of arguments
91+
right before parameters, at the index `p_start`.
92+
- `wrap_code`: Forwarded to `build_function`.
93+
- `add_observed`: Whether to add assignment statements for observed equations in the
94+
generated code.
95+
- `filter_observed`: A predicate function to filter out observed equations which should
96+
not be added to the generated code.
97+
- `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
98+
`args` in the generated code. If `false`, all usages of the individual symbolics will
99+
instead call `getindex` on the relevant argument. This is useful if the generated
100+
function writes to one of its arguments and expects subsequent code to use the new
101+
values. Note that the collapsed `MTKParameters` argument will always be explicitly
102+
destructured regardless of this keyword argument.
103+
- `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
104+
this will be passed to the `similarto` argument of `build_function`. If `output_type`
105+
is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
106+
whether it is scalar or an array).
107+
- `mkarray`: A function which accepts `expr` and `output_type` and returns a code
108+
generation object similar to `MakeArray` or `MakeTuple` to be used to generate
109+
code for `expr`.
110+
- `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
111+
argument.
112+
113+
All other keyword arguments are forwarded to `build_function`.
114+
"""
48115
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), wrap_delays = is_dde(sys), wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = true, output_type = nothing, mkarray = nothing, wrap_mtkparameters = true, kwargs...)
49116
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
50-
117+
# filter observed equations
51118
obs = filter(filter_observed, observed(sys))
119+
# turn delayed unknowns into calls to the history function
52120
if wrap_delays
53121
history_arg = is_split(sys) ? MTKPARAMETERS_ARG : generated_argument_name(p_start)
54122
obs = map(obs) do eq
55123
delay_to_function(sys, eq; history_arg)
56124
end
57125
expr = delay_to_function(sys, expr; history_arg)
58-
args = (args[1:p_start-1]..., DDE_HISTORY_FUN, args[p_start:end]...)
126+
# add extra argument
127+
args = (args[1:(p_start - 1)]..., DDE_HISTORY_FUN, args[p_start:end]...)
59128
p_start += 1
60129
p_end += 1
61130
end
62131
pdeps = parameter_dependencies(sys)
63-
132+
# get the constants to add to the code
64133
cmap, _ = get_cmap(sys)
65134
extra_constants = collect_constants(expr)
66135
filter!(extra_constants) do c
@@ -69,13 +138,15 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
69138
for c in extra_constants
70139
push!(cmap, c ~ getdefault(c))
71140
end
141+
# only get the necessary observed equations, avoiding extra computation
72142
if add_observed
73143
obsidxs = observed_equations_used_by(sys, expr)
74144
else
75145
obsidxs = Int[]
76146
end
147+
# similarly for parameter dependency equations
77148
pdepidxs = observed_equations_used_by(sys, expr; obs = pdeps)
78-
149+
# assignments for reconstructing scalarized array symbolics
79150
assignments = array_variable_assignments(args...)
80151

81152
for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs]))
@@ -84,6 +155,9 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
84155

85156
args = ntuple(Val(length(args))) do i
86157
arg = args[i]
158+
# for time-dependent systems, all arguments are passed through `time_varying_as_func`
159+
# TODO: This is legacy behavior and a candidate for removal in v10 since we have callable
160+
# parameters now.
87161
if is_time_dependent(sys)
88162
arg = if symbolic_type(arg) == NotSymbolic()
89163
arg isa AbstractArray ?
@@ -92,16 +166,19 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
92166
time_varying_as_func(unwrap(arg), sys)
93167
end
94168
end
169+
# Make sure to use the proper names for arguments
95170
if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
96171
DestructuredArgs(arg, generated_argument_name(i); create_bindings)
97172
else
98173
arg
99174
end
100175
end
101176

177+
# wrap into a single MTKParameters argument
102178
if is_split(sys) && wrap_mtkparameters
103179
if p_start > p_end
104-
args = (args[1:p_start-1]..., MTKPARAMETERS_ARG, args[p_end+1:end]...)
180+
# In case there are no parameter buffers, still insert an argument
181+
args = (args[1:(p_start - 1)]..., MTKPARAMETERS_ARG, args[(p_end + 1):end]...)
105182
else
106183
# cannot apply `create_bindings` here since it doesn't nest
107184
args = (args[1:(p_start - 1)]...,
@@ -110,12 +187,14 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
110187
end
111188
end
112189

190+
# add preface assignments
113191
if has_preface(sys) && (pref = preface(sys)) !== nothing
114192
append!(assignments, pref)
115193
end
116194

117195
wrap_code = wrap_code .∘ wrap_assignments(isscalar, assignments)
118196

197+
# handling of `output_type` and `mkarray`
119198
similarto = nothing
120199
if output_type === Tuple
121200
expr = MakeTuple(Tuple(expr))
@@ -127,6 +206,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
127206
wrap_code = wrap_code[2]
128207
end
129208

209+
# scalar `build_function` only accepts a single function for `wrap_code`.
130210
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
131211
wrap_code = wrap_code[1]
132212
end

0 commit comments

Comments
 (0)