1+ """
2+ $(TYPEDSIGNATURES)
3+
4+ Return the name for the `i`th argument in a function generated by `build_function_wrapper`.
5+ """
16function generated_argument_name (i:: Int )
27 return Symbol (:__mtk_arg_ , i)
38end
49
10+ """
11+ $(TYPEDSIGNATURES)
12+
13+ Given the arguments to `build_function_wrapper`, return a list of assignments which
14+ reconstruct array variables if they are present scalarized in `args`.
15+ """
516function array_variable_assignments (args... )
17+ # map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
618 var_to_arridxs = Dict {BasicSymbolic, Array{Tuple{Int, Int}}} ()
719 for (i, arg) in enumerate (args)
20+ # filter out non-arrays
21+ # any element of args which is not an array is assumed to not contain a
22+ # scalarized array symbolic. This works because the only non-array element
23+ # is the independent variable
824 symbolic_type (arg) == NotSymbolic () || continue
925 arg isa AbstractArray || continue
1026
27+ # go through symbolics
1128 for (j, var) in enumerate (arg)
1229 var = unwrap (var)
30+ # filter out non-array-symbolics
1331 iscall (var) || continue
1432 operation (var) == getindex || continue
1533 arrvar = arguments (var)[1 ]
16- idxbuffer = get! (() -> map (Returns ((0 , 0 )), eachindex (arrvar)), var_to_arridxs, arrvar)
34+ # get and/or construct the buffer storing indexes
35+ idxbuffer = get! (
36+ () -> map (Returns ((0 , 0 )), eachindex (arrvar)), var_to_arridxs, arrvar)
1737 idxbuffer[arguments (var)[2 : end ]. .. ] = (i, j)
1838 end
1939 end
2040
2141 assignments = Assignment[]
2242 for (arrvar, idxs) in var_to_arridxs
43+ # all elements of the array need to be present in `args` to form the
44+ # reconstructing assignment
2345 any (iszero ∘ first, idxs) && continue
2446
47+ # if they are all in the same buffer, we can take a shortcut and `view` into it
2548 if allequal (Iterators. map (first, idxs))
2649 buffer_idx = first (first (idxs))
2750 idxs = map (last, idxs)
51+ # if all the elements are contiguous and ordered, turn the array of indexes into a range
52+ # to help reduce allocations
2853 if first (idxs) < last (idxs) && vec (idxs) == first (idxs): last (idxs)
2954 idxs = first (idxs): last (idxs)
3055 elseif vec (idxs) == last (idxs): - 1 : first (idxs)
3156 idxs = last (idxs): - 1 : first (idxs)
3257 else
58+ # Otherwise, turn the indexes into an `SArray` so they're stack-allocated
3359 idxs = SArray {Tuple{size(idxs)...}} (idxs)
3460 end
61+ # view and reshape
3562 push! (assignments, arrvar ← term (reshape, term (view, generated_argument_name (buffer_idx), idxs), size (arrvar)))
3663 else
3764 elems = map (idxs) do idx
3865 i, j = idx
3966 term (getindex, generated_argument_name (i), j)
4067 end
68+ # use `MakeArray` and generate a stack-allocated array
4169 push! (assignments, arrvar ← MakeArray (elems, SArray))
4270 end
4371 end
4472
4573 return assignments
4674end
4775
76+ """
77+ $(TYPEDSIGNATURES)
78+
79+ A wrapper around `build_function` which performs the necessary transformations for
80+ code generation of all types of systems. `expr` is the expression returned from the
81+ generated functions, and `args` are the arguments.
82+
83+ # Keyword Arguments
84+
85+ - `p_start`, `p_end`: Denotes the indexes in `args` where the buffers of the splatted
86+ `MTKParameters` object are present. These are collapsed into a single argument and
87+ destructured inside the function. `p_start` must also be provided for non-split systems
88+ since it is used by `wrap_delays`.
89+ - `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
90+ calls to a history function. The history function is added to the list of arguments
91+ right before parameters, at the index `p_start`.
92+ - `wrap_code`: Forwarded to `build_function`.
93+ - `add_observed`: Whether to add assignment statements for observed equations in the
94+ generated code.
95+ - `filter_observed`: A predicate function to filter out observed equations which should
96+ not be added to the generated code.
97+ - `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
98+ `args` in the generated code. If `false`, all usages of the individual symbolics will
99+ instead call `getindex` on the relevant argument. This is useful if the generated
100+ function writes to one of its arguments and expects subsequent code to use the new
101+ values. Note that the collapsed `MTKParameters` argument will always be explicitly
102+ destructured regardless of this keyword argument.
103+ - `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
104+ this will be passed to the `similarto` argument of `build_function`. If `output_type`
105+ is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
106+ whether it is scalar or an array).
107+ - `mkarray`: A function which accepts `expr` and `output_type` and returns a code
108+ generation object similar to `MakeArray` or `MakeTuple` to be used to generate
109+ code for `expr`.
110+ - `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
111+ argument.
112+
113+ All other keyword arguments are forwarded to `build_function`.
114+ """
48115function build_function_wrapper (sys:: AbstractSystem , expr, args... ; p_start = 2 , p_end = is_time_dependent (sys) ? length (args) - 1 : length (args), wrap_delays = is_dde (sys), wrap_code = identity, add_observed = true , filter_observed = Returns (true ), create_bindings = true , output_type = nothing , mkarray = nothing , wrap_mtkparameters = true , kwargs... )
49116 isscalar = ! (expr isa AbstractArray || symbolic_type (expr) == ArraySymbolic ())
50-
117+ # filter observed equations
51118 obs = filter (filter_observed, observed (sys))
119+ # turn delayed unknowns into calls to the history function
52120 if wrap_delays
53121 history_arg = is_split (sys) ? MTKPARAMETERS_ARG : generated_argument_name (p_start)
54122 obs = map (obs) do eq
55123 delay_to_function (sys, eq; history_arg)
56124 end
57125 expr = delay_to_function (sys, expr; history_arg)
58- args = (args[1 : p_start- 1 ]. .. , DDE_HISTORY_FUN, args[p_start: end ]. .. )
126+ # add extra argument
127+ args = (args[1 : (p_start - 1 )]. .. , DDE_HISTORY_FUN, args[p_start: end ]. .. )
59128 p_start += 1
60129 p_end += 1
61130 end
62131 pdeps = parameter_dependencies (sys)
63-
132+ # get the constants to add to the code
64133 cmap, _ = get_cmap (sys)
65134 extra_constants = collect_constants (expr)
66135 filter! (extra_constants) do c
@@ -69,13 +138,15 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
69138 for c in extra_constants
70139 push! (cmap, c ~ getdefault (c))
71140 end
141+ # only get the necessary observed equations, avoiding extra computation
72142 if add_observed
73143 obsidxs = observed_equations_used_by (sys, expr)
74144 else
75145 obsidxs = Int[]
76146 end
147+ # similarly for parameter dependency equations
77148 pdepidxs = observed_equations_used_by (sys, expr; obs = pdeps)
78-
149+ # assignments for reconstructing scalarized array symbolics
79150 assignments = array_variable_assignments (args... )
80151
81152 for eq in Iterators. flatten ((cmap, pdeps[pdepidxs], obs[obsidxs]))
@@ -84,6 +155,9 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
84155
85156 args = ntuple (Val (length (args))) do i
86157 arg = args[i]
158+ # for time-dependent systems, all arguments are passed through `time_varying_as_func`
159+ # TODO : This is legacy behavior and a candidate for removal in v10 since we have callable
160+ # parameters now.
87161 if is_time_dependent (sys)
88162 arg = if symbolic_type (arg) == NotSymbolic ()
89163 arg isa AbstractArray ?
@@ -92,16 +166,19 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
92166 time_varying_as_func (unwrap (arg), sys)
93167 end
94168 end
169+ # Make sure to use the proper names for arguments
95170 if symbolic_type (arg) == NotSymbolic () && arg isa AbstractArray
96171 DestructuredArgs (arg, generated_argument_name (i); create_bindings)
97172 else
98173 arg
99174 end
100175 end
101176
177+ # wrap into a single MTKParameters argument
102178 if is_split (sys) && wrap_mtkparameters
103179 if p_start > p_end
104- args = (args[1 : p_start- 1 ]. .. , MTKPARAMETERS_ARG, args[p_end+ 1 : end ]. .. )
180+ # In case there are no parameter buffers, still insert an argument
181+ args = (args[1 : (p_start - 1 )]. .. , MTKPARAMETERS_ARG, args[(p_end + 1 ): end ]. .. )
105182 else
106183 # cannot apply `create_bindings` here since it doesn't nest
107184 args = (args[1 : (p_start - 1 )]. .. ,
@@ -110,12 +187,14 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
110187 end
111188 end
112189
190+ # add preface assignments
113191 if has_preface (sys) && (pref = preface (sys)) != = nothing
114192 append! (assignments, pref)
115193 end
116194
117195 wrap_code = wrap_code .∘ wrap_assignments (isscalar, assignments)
118196
197+ # handling of `output_type` and `mkarray`
119198 similarto = nothing
120199 if output_type === Tuple
121200 expr = MakeTuple (Tuple (expr))
@@ -127,6 +206,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
127206 wrap_code = wrap_code[2 ]
128207 end
129208
209+ # scalar `build_function` only accepts a single function for `wrap_code`.
130210 if wrap_code isa Tuple && symbolic_type (expr) == ScalarSymbolic ()
131211 wrap_code = wrap_code[1 ]
132212 end
0 commit comments