1+ """
2+ $(TYPEDSIGNATURES)
3+
4+ Return the name for the `i`th argument in a function generated by `build_function_wrapper`.
5+ """
16function generated_argument_name (i:: Int )
27 return Symbol (:__mtk_arg_ , i)
38end
49
10+ """
11+ $(TYPEDSIGNATURES)
12+
13+ Given the arguments to `build_function_wrapper`, return a list of assignments which
14+ reconstruct array variables if they are present scalarized in `args`.
15+ """
516function array_variable_assignments (args... )
17+ # map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
618 var_to_arridxs = Dict {BasicSymbolic, Array{Tuple{Int, Int}}} ()
719 for (i, arg) in enumerate (args)
20+ # filter out non-arrays
21+ # any element of args which is not an array is assumed to not contain a
22+ # scalarized array symbolic. This works because the only non-array element
23+ # is the independent variable
824 symbolic_type (arg) == NotSymbolic () || continue
925 arg isa AbstractArray || continue
1026
27+ # go through symbolics
1128 for (j, var) in enumerate (arg)
1229 var = unwrap (var)
30+ # filter out non-array-symbolics
1331 iscall (var) || continue
1432 operation (var) == getindex || continue
1533 arrvar = arguments (var)[1 ]
34+ # get and/or construct the buffer storing indexes
1635 idxbuffer = get! (() -> map (Returns ((0 , 0 )), eachindex (arrvar)), var_to_arridxs, arrvar)
1736 idxbuffer[arguments (var)[2 : end ]. .. ] = (i, j)
1837 end
1938 end
2039
2140 assignments = Assignment[]
2241 for (arrvar, idxs) in var_to_arridxs
42+ # all elements of the array need to be present in `args` to form the
43+ # reconstructing assignment
2344 any (iszero ∘ first, idxs) && continue
2445
46+ # if they are all in the same buffer, we can take a shortcut and `view` into it
2547 if allequal (Iterators. map (first, idxs))
2648 buffer_idx = first (first (idxs))
2749 idxs = map (last, idxs)
50+ # if all the elements are contiguous and ordered, turn the array of indexes into a range
51+ # to help reduce allocations
2852 if first (idxs) < last (idxs) && vec (idxs) == first (idxs): last (idxs)
2953 idxs = first (idxs): last (idxs)
3054 elseif vec (idxs) == last (idxs): - 1 : first (idxs)
3155 idxs = last (idxs): - 1 : first (idxs)
3256 else
57+ # Otherwise, turn the indexes into an `SArray` so they're stack-allocated
3358 idxs = SArray {Tuple{size(idxs)...}} (idxs)
3459 end
60+ # view and reshape
3561 push! (assignments, arrvar ← term (reshape, term (view, generated_argument_name (buffer_idx), idxs), size (arrvar)))
3662 else
3763 elems = map (idxs) do idx
3864 i, j = idx
3965 term (getindex, generated_argument_name (i), j)
4066 end
67+ # use `MakeArray` and generate a stack-allocated array
4168 push! (assignments, arrvar ← MakeArray (elems, SArray))
4269 end
4370 end
4471
4572 return assignments
4673end
4774
75+ """
76+ $(TYPEDSIGNATURES)
77+
78+ A wrapper around `build_function` which performs the necessary transformations for
79+ code generation of all types of systems. `expr` is the expression returned from the
80+ generated functions, and `args` are the arguments.
81+
82+ # Keyword Arguments
83+
84+ - `p_start`, `p_end`: Denotes the indexes in `args` where the buffers of the splatted
85+ `MTKParameters` object are present. These are collapsed into a single argument and
86+ destructured inside the function. `p_start` must also be provided for non-split systems
87+ since it is used by `wrap_delays`.
88+ - `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
89+ calls to a history function. The history function is added to the list of arguments
90+ right before parameters, at the index `p_start`.
91+ - `wrap_code`: Forwarded to `build_function`.
92+ - `add_observed`: Whether to add assignment statements for observed equations in the
93+ generated code.
94+ - `filter_observed`: A predicate function to filter out observed equations which should
95+ not be added to the generated code.
96+ - `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
97+ `args` in the generated code. If `false`, all usages of the individual symbolics will
98+ instead call `getindex` on the relevant argument. This is useful if the generated
99+ function writes to one of its arguments and expects subsequent code to use the new
100+ values. Note that the collapsed `MTKParameters` argument will always be explicitly
101+ destructured regardless of this keyword argument.
102+ - `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
103+ this will be passed to the `similarto` argument of `build_function`. If `output_type`
104+ is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
105+ whether it is scalar or an array).
106+ - `mkarray`: A function which accepts `expr` and `output_type` and returns a code
107+ generation object similar to `MakeArray` or `MakeTuple` to be used to generate
108+ code for `expr`.
109+ - `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
110+ argument.
111+
112+ All other keyword arguments are forwarded to `build_function`.
113+ """
48114function build_function_wrapper (sys:: AbstractSystem , expr, args... ; p_start = 2 , p_end = is_time_dependent (sys) ? length (args) - 1 : length (args), wrap_delays = is_dde (sys), wrap_code = identity, add_observed = true , filter_observed = Returns (true ), create_bindings = true , output_type = nothing , mkarray = nothing , wrap_mtkparameters = true , kwargs... )
49115 isscalar = ! (expr isa AbstractArray || symbolic_type (expr) == ArraySymbolic ())
50-
116+ # filter observed equations
51117 obs = filter (filter_observed, observed (sys))
118+ # turn delayed unknowns into calls to the history function
52119 if wrap_delays
53120 history_arg = is_split (sys) ? MTKPARAMETERS_ARG : generated_argument_name (p_start)
54121 obs = map (obs) do eq
55122 delay_to_function (sys, eq; history_arg)
56123 end
57124 expr = delay_to_function (sys, expr; history_arg)
125+ # add extra argument
58126 args = (args[1 : p_start- 1 ]. .. , DDE_HISTORY_FUN, args[p_start: end ]. .. )
59127 p_start += 1
60128 p_end += 1
61129 end
62130 pdeps = parameter_dependencies (sys)
63-
131+ # get the constants to add to the code
64132 cmap, _ = get_cmap (sys)
65133 extra_constants = collect_constants (expr)
66134 filter! (extra_constants) do c
@@ -69,13 +137,15 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
69137 for c in extra_constants
70138 push! (cmap, c ~ getdefault (c))
71139 end
140+ # only get the necessary observed equations, avoiding extra computation
72141 if add_observed
73142 obsidxs = observed_equations_used_by (sys, expr)
74143 else
75144 obsidxs = Int[]
76145 end
146+ # similarly for parameter dependency equations
77147 pdepidxs = observed_equations_used_by (sys, expr; obs = pdeps)
78-
148+ # assignments for reconstructing scalarized array symbolics
79149 assignments = array_variable_assignments (args... )
80150
81151 for eq in Iterators. flatten ((cmap, pdeps[pdepidxs], obs[obsidxs]))
@@ -84,35 +154,43 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
84154
85155 args = ntuple (Val (length (args))) do i
86156 arg = args[i]
157+ # for time-dependent systems, all arguments are passed through `time_varying_as_func`
158+ # TODO : This is legacy behavior and a candidate for removal in v10 since we have callable
159+ # parameters now.
87160 if is_time_dependent (sys)
88161 arg = if symbolic_type (arg) == NotSymbolic ()
89162 arg isa AbstractArray ? map (x -> time_varying_as_func (unwrap (x), sys), arg) : arg
90163 else
91164 time_varying_as_func (unwrap (arg), sys)
92165 end
93166 end
167+ # Make sure to use the proper names for arguments
94168 if symbolic_type (arg) == NotSymbolic () && arg isa AbstractArray
95169 DestructuredArgs (arg, generated_argument_name (i); create_bindings)
96170 else
97171 arg
98172 end
99173 end
100174
175+ # wrap into a single MTKParameters argument
101176 if is_split (sys) && wrap_mtkparameters
102177 if p_start > p_end
178+ # In case there are no parameter buffers, still insert an argument
103179 args = (args[1 : p_start- 1 ]. .. , MTKPARAMETERS_ARG, args[p_end+ 1 : end ]. .. )
104180 else
105181 # cannot apply `create_bindings` here since it doesn't nest
106182 args = (args[1 : p_start- 1 ]. .. , DestructuredArgs (collect (args[p_start: p_end]), MTKPARAMETERS_ARG), args[p_end+ 1 : end ]. .. )
107183 end
108184 end
109185
186+ # add preface assignments
110187 if has_preface (sys) && (pref = preface (sys)) != = nothing
111188 append! (assignments, pref)
112189 end
113190
114191 wrap_code = wrap_code .∘ wrap_assignments (isscalar, assignments)
115192
193+ # handling of `output_type` and `mkarray`
116194 similarto = nothing
117195 if output_type === Tuple
118196 expr = MakeTuple (Tuple (expr))
@@ -124,6 +202,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
124202 wrap_code = wrap_code[2 ]
125203 end
126204
205+ # scalar `build_function` only accepts a single function for `wrap_code`.
127206 if wrap_code isa Tuple && symbolic_type (expr) == ScalarSymbolic ()
128207 wrap_code = wrap_code[1 ]
129208 end
0 commit comments