@@ -21,7 +21,7 @@ See [`getu`](@ref) and [`setu`](@ref) for further details.
2121
2222See also: [`associated_systems`](@ref).
2323"""
24- struct BatchedInterface{S <: AbstractVector , I, T}
24+ struct BatchedInterface{S <: AbstractVector , I, T, P }
2525 " Order of symbols in the union."
2626 symbol_order:: S
2727 " Index of the index provider each symbol in the union is associated with."
@@ -36,6 +36,8 @@ struct BatchedInterface{S <: AbstractVector, I, T}
3636 system_to_symbol_indexes:: Vector{Vector{T}}
3737 " Map from index provider to whether each of its symbols is a state in the index provider."
3838 system_to_isstate:: Vector{BitVector}
39+ " Index providers, in order"
40+ index_providers:: Vector{P}
3941end
4042
4143function BatchedInterface (syssyms:: Tuple... )
@@ -46,11 +48,18 @@ function BatchedInterface(syssyms::Tuple...)
4648 system_to_symbol_subset = Vector{Int}[]
4749 system_to_symbol_indexes = []
4850 system_to_isstate = BitVector[]
51+ index_providers = []
4952 for (i, (sys, syms)) in enumerate (syssyms)
5053 symbol_subset = Int[]
5154 symbol_indexes = []
5255 system_isstate = BitVector ()
5356 allsyms = []
57+ root_indp = sys
58+ while applicable (symbolic_container, root_indp) &&
59+ (sc = symbolic_container (root_indp)) != root_indp
60+ root_indp = sc
61+ end
62+ push! (index_providers, root_indp)
5463 for sym in syms
5564 if symbolic_type (sym) === NotSymbolic ()
5665 error (" Only symbolic variables allowed in BatchedInterface." )
@@ -89,9 +98,10 @@ function BatchedInterface(syssyms::Tuple...)
8998 system_to_symbol_indexes = identity .(system_to_symbol_indexes)
9099
91100 return BatchedInterface{typeof (symbol_order), typeof (associated_indexes),
92- eltype (eltype (system_to_symbol_indexes))}(
101+ eltype (eltype (system_to_symbol_indexes)), eltype (index_providers) }(
93102 symbol_order, associated_systems, associated_indexes, isstate,
94- system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
103+ system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate,
104+ identity .(index_providers))
95105end
96106
97107variable_symbols (bi:: BatchedInterface ) = bi. symbol_order
@@ -268,3 +278,102 @@ function setu(bi::BatchedInterface)
268278 setter!
269279 end
270280end
281+
282+ """
283+ setsym_oop(bi::BatchedInterface)
284+
285+ Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
286+ symbols), return a function which takes `n` corresponding value providers and an array of
287+ values, and returns an `n`-tuple where each element is a 2-tuple consisting of the updated
288+ state values and parameter values of the corresponding value provider. Requires that the
289+ value provider implement [`state_values`](@ref), [`parameter_values`](@ref). The updates are
290+ performed out-of-place using [`remake_buffer`](@ref).
291+
292+ Note that all of the value providers passed to the returned function must satisfy
293+ `is_timeseries(prob) === NotTimeseries()`.
294+
295+ Note that if any subset of the `n` index providers share common symbols (among those passed
296+ to `BatchedInterface`) then all of the corresponding value providers in the subset will be
297+ updated with the values of the common symbols.
298+
299+ See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
300+ """
301+ function setsym_oop (bi:: BatchedInterface )
302+ numprobs = length (bi. system_to_symbol_subset)
303+ probnames = [Symbol (:prob , i) for i in 1 : numprobs]
304+ arg = :vals
305+ full_update = Expr (:block )
306+
307+ function get_update_expr (prob:: Symbol , sys_i:: Int )
308+ union_idxs = bi. system_to_symbol_subset[sys_i]
309+ indp_idxs = bi. system_to_symbol_indexes[sys_i]
310+ isstate = bi. system_to_isstate[sys_i]
311+ indp = bi. index_providers[sys_i]
312+ curexpr = Expr (:block )
313+
314+ statessym = Symbol (:states_ , sys_i)
315+ if all (.! isstate)
316+ push! (curexpr. args, :($ statessym = $ state_values ($ prob)))
317+ else
318+ state_idxssym = Symbol (:state_idxs_ , sys_i)
319+ state_idxs = indp_idxs[isstate]
320+ state_valssym = Symbol (:state_vals_ , sys_i)
321+ vals_idxs = union_idxs[isstate]
322+ push! (curexpr. args, :($ state_idxssym = $ state_idxs))
323+ push! (curexpr. args, :($ state_valssym = $ view ($ arg, $ vals_idxs)))
324+ push! (curexpr. args,
325+ :($ statessym = $ remake_buffer (
326+ syss[$ sys_i], $ state_values ($ prob), $ state_idxssym, $ state_valssym)))
327+ end
328+
329+ paramssym = Symbol (:params_ , sys_i)
330+ if all (isstate)
331+ push! (curexpr. args, :($ paramssym = $ parameter_values ($ prob)))
332+ else
333+ param_idxssym = Symbol (:param_idxs_ , sys_i)
334+ param_idxs = indp_idxs[.! isstate]
335+ param_valssym = Symbol (:param_vals , sys_i)
336+ vals_idxs = union_idxs[.! isstate]
337+ push! (curexpr. args, :($ param_idxssym = $ param_idxs))
338+ push! (curexpr. args, :($ param_valssym = $ view ($ arg, $ vals_idxs)))
339+ push! (curexpr. args,
340+ :($ paramssym = $ remake_buffer (
341+ syss[$ sys_i], $ parameter_values ($ prob), $ param_idxssym, $ param_valssym)))
342+ end
343+
344+ return curexpr, statessym, paramssym
345+ end
346+
347+ full_update_expr = Expr (:block )
348+ full_update_retval = Expr (:tuple )
349+ partial_update_expr = Expr (:block )
350+ cur_partial_update_expr = partial_update_expr
351+ for i in 1 : numprobs
352+ update_expr, statesym, paramsym = get_update_expr (probnames[i], i)
353+ push! (full_update_expr. args, update_expr)
354+ push! (full_update_retval. args, Expr (:tuple , statesym, paramsym))
355+
356+ cur_ifexpr = Expr (i == 1 ? :if : :elseif , :(idx == $ i))
357+ update_expr, statesym, paramsym = get_update_expr (:prob , i)
358+ push! (update_expr. args, :(return ($ statesym, $ paramsym)))
359+ push! (cur_ifexpr. args, update_expr)
360+ push! (cur_partial_update_expr. args, cur_ifexpr)
361+ cur_partial_update_expr = cur_ifexpr
362+ end
363+ push! (full_update_expr. args, :(return $ full_update_retval))
364+ push! (cur_partial_update_expr. args, :(error (" Invalid problem index $idx " )))
365+
366+ full_update_fnexpr = Expr (
367+ :function , Expr (:tuple , :syss , probnames... , arg), full_update_expr)
368+ partial_update_fnexpr = Expr (
369+ :function , Expr (:tuple , :syss , :prob , :idx , arg), partial_update_expr)
370+
371+ return let full_update = @RuntimeGeneratedFunction (full_update_fnexpr),
372+ partial_update = @RuntimeGeneratedFunction (partial_update_fnexpr),
373+ syss = Tuple (bi. index_providers)
374+
375+ setter (args... ) = full_update (syss, args... )
376+ setter (prob, idx:: Int , vals:: AbstractVector ) = partial_update (syss, prob, idx, vals)
377+ setter
378+ end
379+ end
0 commit comments