@@ -206,217 +206,8 @@ function wrap_assignments(isscalar, assignments; let_block = false)
206206 end
207207end
208208
209- function wrap_parameter_dependencies (sys:: AbstractSystem , isscalar)
210- wrap_assignments (isscalar, [eq. lhs ← eq. rhs for eq in parameter_dependencies (sys)])
211- end
212-
213- """
214- $(TYPEDSIGNATURES)
215-
216- Add the necessary assignment statements to allow use of unscalarized array variables
217- in the generated code. `expr` is the expression returned by the function. `dvs` and
218- `ps` are the unknowns and parameters of the system `sys` to use in the generated code.
219- `inputs` can be specified as an array of symbolics if the generated function has inputs.
220- If `history == true`, the generated function accepts a history function. `cachesyms` are
221- extra variables (arrays of variables) stored in the cache array(s) of the parameter
222- object. `extra_args` are extra arguments appended to the end of the argument list.
223-
224- The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)`
225- Where:
226- - `du` is the optional buffer to write to for in-place functions.
227- - `u` is the list of unknowns. This argument is not present if `dvs === nothing`.
228- - `h` is the optional history function, present if `history == true`.
229- - `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed
230- to be in the order of variables passed to `inputs`.
231- - `p` is the parameter object.
232- - `cache_syms` are the cache variables. These are part of the splatted parameter object.
233- - `t` is time, present only if the system is time dependent.
234- - `extra_args` are the extra arguments passed to the function, present only if
235- `extra_args` is non-empty.
236- """
237- function wrap_array_vars (
238- sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys),
239- inputs = nothing , history = false , cachesyms:: Tuple = (), extra_args:: Tuple = ())
240- isscalar = ! (exprs isa AbstractArray)
241- var_to_arridxs = Dict ()
242-
243- if dvs === nothing
244- uind = 0
245- else
246- uind = 1
247- for (i, x) in enumerate (dvs)
248- iscall (x) && operation (x) == getindex || continue
249- arg = arguments (x)[1 ]
250- inds = get! (() -> [], var_to_arridxs, arg)
251- push! (inds, (uind, i))
252- end
253- end
254- p_start = uind + 1 + history
255- rps = (reorder_parameters (sys, ps)... , cachesyms... )
256- if inputs != = nothing
257- rps = (inputs, rps... )
258- end
259- if has_iv (sys)
260- rps = (rps... , get_iv (sys))
261- end
262- rps = (rps... , extra_args... )
263- for sym in reduce (vcat, rps; init = [])
264- iscall (sym) && operation (sym) == getindex || continue
265- arg = arguments (sym)[1 ]
266-
267- bufferidx = findfirst (buf -> any (isequal (sym), buf), rps)
268- idxinbuffer = findfirst (isequal (sym), rps[bufferidx])
269- inds = get! (() -> [], var_to_arridxs, arg)
270- push! (inds, (p_start + bufferidx - 1 , idxinbuffer))
271- end
272-
273- viewsyms = Dict ()
274- splitsyms = Dict ()
275- for (arrsym, idxs) in var_to_arridxs
276- length (idxs) == length (arrsym) || continue
277- # allequal(first, idxs) is a 1.11 feature
278- if allequal (Iterators. map (first, idxs))
279- viewsyms[arrsym] = (first (first (idxs)), reshape (last .(idxs), size (arrsym)))
280- else
281- splitsyms[arrsym] = reshape (idxs, size (arrsym))
282- end
283- end
284- if isscalar
285- function (expr)
286- Func (
287- expr. args,
288- [],
289- Let (
290- vcat (
291- [sym ← :(view ($ (expr. args[i]. name), $ idxs))
292- for (sym, (i, idxs)) in viewsyms],
293- [sym ←
294- MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
295- expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
296- ),
297- expr. body,
298- false
299- )
300- )
301- end
302- else
303- function (expr)
304- Func (
305- expr. args,
306- [],
307- Let (
308- vcat (
309- [sym ← :(view ($ (expr. args[i]. name), $ idxs))
310- for (sym, (i, idxs)) in viewsyms],
311- [sym ←
312- MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
313- expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
314- ),
315- expr. body,
316- false
317- )
318- )
319- end ,
320- function (expr)
321- Func (
322- expr. args,
323- [],
324- Let (
325- vcat (
326- [sym ← :(view ($ (expr. args[i + 1 ]. name), $ idxs))
327- for (sym, (i, idxs)) in viewsyms],
328- [sym ← MakeArray (
329- [expr. args[bufi + 1 ]. elems[vali] for (bufi, vali) in idxs],
330- expr. args[idxs[1 ][1 ] + 1 ]) for (sym, idxs) in splitsyms]
331- ),
332- expr. body,
333- false
334- )
335- )
336- end
337- end
338- end
339-
340209const MTKPARAMETERS_ARG = Sym {Vector{Vector}} (:___mtkparameters___ )
341210
342- """
343- wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, offset = Int(is_time_dependent(sys)))
344-
345- Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
346- allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
347- instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
348- wrapped is for a scalar value. `p_start` is the index of the argument containing
349- the first parameter vector in the out-of-place version of the function. For example,
350- if a history function (DDEs) was passed before `p`, then the function before wrapping
351- would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
352-
353- `offset` is the number of arguments at the end of the argument list to ignore. Defaults
354- to 1 if the system is time-dependent (to ignore `t`) and 0 otherwise.
355-
356- The returned function is `identity` if the system does not have an `IndexCache`.
357- """
358- function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool , p_start = 2 ,
359- offset = Int (is_time_dependent (sys)))
360- if has_index_cache (sys) && get_index_cache (sys) != = nothing
361- if isscalar
362- function (expr)
363- param_args = expr. args[p_start: (end - offset)]
364- param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
365- param_buffer_args = param_args[param_buffer_idxs]
366- destructured_mtkparams = DestructuredArgs (
367- [x. name for x in param_buffer_args],
368- MTKPARAMETERS_ARG; inds = param_buffer_idxs)
369- Func (
370- [
371- expr. args[begin : (p_start - 1 )]. .. ,
372- destructured_mtkparams,
373- expr. args[(end - offset + 1 ): end ]. ..
374- ],
375- [],
376- Let (param_buffer_args, expr. body, false )
377- )
378- end
379- else
380- function (expr)
381- param_args = expr. args[p_start: (end - offset)]
382- param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
383- param_buffer_args = param_args[param_buffer_idxs]
384- destructured_mtkparams = DestructuredArgs (
385- [x. name for x in param_buffer_args],
386- MTKPARAMETERS_ARG; inds = param_buffer_idxs)
387- Func (
388- [
389- expr. args[begin : (p_start - 1 )]. .. ,
390- destructured_mtkparams,
391- expr. args[(end - offset + 1 ): end ]. ..
392- ],
393- [],
394- Let (param_buffer_args, expr. body, false )
395- )
396- end ,
397- function (expr)
398- param_args = expr. args[(p_start + 1 ): (end - offset)]
399- param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
400- param_buffer_args = param_args[param_buffer_idxs]
401- destructured_mtkparams = DestructuredArgs (
402- [x. name for x in param_buffer_args],
403- MTKPARAMETERS_ARG; inds = param_buffer_idxs)
404- Func (
405- [
406- expr. args[begin : p_start]. .. ,
407- destructured_mtkparams,
408- expr. args[(end - offset + 1 ): end ]. ..
409- ],
410- [],
411- Let (param_buffer_args, expr. body, false )
412- )
413- end
414- end
415- else
416- identity
417- end
418- end
419-
420211mutable struct Substitutions
421212 subs:: Vector{Equation}
422213 deps:: Vector{Vector{Int}}
0 commit comments