Skip to content

Commit a8a7d4f

Browse files
feat: add setsym_oop for BatchedInterface
1 parent 742dc91 commit a8a7d4f

File tree

5 files changed

+175
-4
lines changed

5 files changed

+175
-4
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ with_updated_parameter_timeseries_values
112112
```@docs
113113
BatchedInterface
114114
associated_systems
115+
setsym_oop
115116
```
116117

117118
## Container objects

src/SymbolicIndexingInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ include("parameter_indexing.jl")
3838
export getu, setu
3939
include("state_indexing.jl")
4040

41-
export BatchedInterface, associated_systems
41+
export BatchedInterface, setsym_oop, associated_systems
4242
include("batched_interface.jl")
4343

4444
export ProblemState

src/batched_interface.jl

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ See [`getu`](@ref) and [`setu`](@ref) for further details.
2121
2222
See also: [`associated_systems`](@ref).
2323
"""
24-
struct BatchedInterface{S <: AbstractVector, I, T}
24+
struct BatchedInterface{S <: AbstractVector, I, T, P}
2525
"Order of symbols in the union."
2626
symbol_order::S
2727
"Index of the index provider each symbol in the union is associated with."
@@ -36,6 +36,8 @@ struct BatchedInterface{S <: AbstractVector, I, T}
3636
system_to_symbol_indexes::Vector{Vector{T}}
3737
"Map from index provider to whether each of its symbols is a state in the index provider."
3838
system_to_isstate::Vector{BitVector}
39+
"Index providers, in order"
40+
index_providers::Vector{P}
3941
end
4042

4143
function BatchedInterface(syssyms::Tuple...)
@@ -46,11 +48,18 @@ function BatchedInterface(syssyms::Tuple...)
4648
system_to_symbol_subset = Vector{Int}[]
4749
system_to_symbol_indexes = []
4850
system_to_isstate = BitVector[]
51+
index_providers = []
4952
for (i, (sys, syms)) in enumerate(syssyms)
5053
symbol_subset = Int[]
5154
symbol_indexes = []
5255
system_isstate = BitVector()
5356
allsyms = []
57+
root_indp = sys
58+
while applicable(symbolic_container, root_indp) &&
59+
(sc = symbolic_container(root_indp)) != root_indp
60+
root_indp = sc
61+
end
62+
push!(index_providers, root_indp)
5463
for sym in syms
5564
if symbolic_type(sym) === NotSymbolic()
5665
error("Only symbolic variables allowed in BatchedInterface.")
@@ -89,9 +98,10 @@ function BatchedInterface(syssyms::Tuple...)
8998
system_to_symbol_indexes = identity.(system_to_symbol_indexes)
9099

91100
return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
92-
eltype(eltype(system_to_symbol_indexes))}(
101+
eltype(eltype(system_to_symbol_indexes)), eltype(index_providers)}(
93102
symbol_order, associated_systems, associated_indexes, isstate,
94-
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
103+
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate,
104+
identity.(index_providers))
95105
end
96106

97107
variable_symbols(bi::BatchedInterface) = bi.symbol_order
@@ -268,3 +278,102 @@ function setu(bi::BatchedInterface)
268278
setter!
269279
end
270280
end
281+
282+
"""
283+
setsym_oop(bi::BatchedInterface)
284+
285+
Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
286+
symbols), return a function which takes `n` corresponding value providers and an array of
287+
values, and returns an `n`-tuple where each element is a 2-tuple consisting of the updated
288+
state values and parameter values of the corresponding value provider. Requires that the
289+
value provider implement [`state_values`](@ref), [`parameter_values`](@ref). The updates are
290+
performed out-of-place using [`remake_buffer`](@ref).
291+
292+
Note that all of the value providers passed to the returned function must satisfy
293+
`is_timeseries(prob) === NotTimeseries()`.
294+
295+
Note that if any subset of the `n` index providers share common symbols (among those passed
296+
to `BatchedInterface`) then all of the corresponding value providers in the subset will be
297+
updated with the values of the common symbols.
298+
299+
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
300+
"""
301+
function setsym_oop(bi::BatchedInterface)
302+
numprobs = length(bi.system_to_symbol_subset)
303+
probnames = [Symbol(:prob, i) for i in 1:numprobs]
304+
arg = :vals
305+
full_update = Expr(:block)
306+
307+
function get_update_expr(prob::Symbol, sys_i::Int)
308+
union_idxs = bi.system_to_symbol_subset[sys_i]
309+
indp_idxs = bi.system_to_symbol_indexes[sys_i]
310+
isstate = bi.system_to_isstate[sys_i]
311+
indp = bi.index_providers[sys_i]
312+
curexpr = Expr(:block)
313+
314+
statessym = Symbol(:states_, sys_i)
315+
if all(.!isstate)
316+
push!(curexpr.args, :($statessym = $state_values($prob)))
317+
else
318+
state_idxssym = Symbol(:state_idxs_, sys_i)
319+
state_idxs = indp_idxs[isstate]
320+
state_valssym = Symbol(:state_vals_, sys_i)
321+
vals_idxs = union_idxs[isstate]
322+
push!(curexpr.args, :($state_idxssym = $state_idxs))
323+
push!(curexpr.args, :($state_valssym = $view($arg, $vals_idxs)))
324+
push!(curexpr.args,
325+
:($statessym = $remake_buffer(
326+
syss[$sys_i], $state_values($prob), $state_idxssym, $state_valssym)))
327+
end
328+
329+
paramssym = Symbol(:params_, sys_i)
330+
if all(isstate)
331+
push!(curexpr.args, :($paramssym = $parameter_values($prob)))
332+
else
333+
param_idxssym = Symbol(:param_idxs_, sys_i)
334+
param_idxs = indp_idxs[.!isstate]
335+
param_valssym = Symbol(:param_vals, sys_i)
336+
vals_idxs = union_idxs[.!isstate]
337+
push!(curexpr.args, :($param_idxssym = $param_idxs))
338+
push!(curexpr.args, :($param_valssym = $view($arg, $vals_idxs)))
339+
push!(curexpr.args,
340+
:($paramssym = $remake_buffer(
341+
syss[$sys_i], $parameter_values($prob), $param_idxssym, $param_valssym)))
342+
end
343+
344+
return curexpr, statessym, paramssym
345+
end
346+
347+
full_update_expr = Expr(:block)
348+
full_update_retval = Expr(:tuple)
349+
partial_update_expr = Expr(:block)
350+
cur_partial_update_expr = partial_update_expr
351+
for i in 1:numprobs
352+
update_expr, statesym, paramsym = get_update_expr(probnames[i], i)
353+
push!(full_update_expr.args, update_expr)
354+
push!(full_update_retval.args, Expr(:tuple, statesym, paramsym))
355+
356+
cur_ifexpr = Expr(i == 1 ? :if : :elseif, :(idx == $i))
357+
update_expr, statesym, paramsym = get_update_expr(:prob, i)
358+
push!(update_expr.args, :(return ($statesym, $paramsym)))
359+
push!(cur_ifexpr.args, update_expr)
360+
push!(cur_partial_update_expr.args, cur_ifexpr)
361+
cur_partial_update_expr = cur_ifexpr
362+
end
363+
push!(full_update_expr.args, :(return $full_update_retval))
364+
push!(cur_partial_update_expr.args, :(error("Invalid problem index $idx")))
365+
366+
full_update_fnexpr = Expr(
367+
:function, Expr(:tuple, :syss, probnames..., arg), full_update_expr)
368+
partial_update_fnexpr = Expr(
369+
:function, Expr(:tuple, :syss, :prob, :idx, arg), partial_update_expr)
370+
371+
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
372+
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr),
373+
syss = Tuple(bi.index_providers)
374+
375+
setter(args...) = full_update(syss, args...)
376+
setter(prob, idx::Int, vals::AbstractVector) = partial_update(syss, prob, idx, vals)
377+
setter
378+
end
379+
end

test/batched_interface_test.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,31 @@ setter!(probs[1], 1, buf)
5454
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]
5555

5656
@test_throws ErrorException setter!(probs[1], 4, buf)
57+
58+
setter!(probs..., buf)
59+
60+
setter = setsym_oop(bi)
61+
62+
buf .*= 100
63+
vals = setter(probs..., buf)
64+
@test length(vals) == length(probs)
65+
@test vals[1][1] == [100.0, 2.0, 300.0]
66+
@test vals[1][2] == [0.1, 20.0, 30.0]
67+
@test vals[2][1] == [300.0, 500.0, 6.0]
68+
@test vals[2][2] == [30.0, 0.5, 60.0]
69+
@test vals[3][1] == [500.0, 100.0, 9.0]
70+
@test vals[3][2] == [70.0, 80.0, 0.9]
71+
72+
# out-of-place
73+
for i in 1:3
74+
@test vals[i][1] != state_values(probs[i])
75+
@test vals[i][2] != parameter_values(probs[i])
76+
end
77+
78+
buf ./= 10
79+
vals = setter(probs[1], 1, buf)
80+
@test length(vals) == 2
81+
@test vals[1] == [10.0, 2.0, 30.0]
82+
@test vals[2] == [0.1, 2.0, 3.0]
83+
84+
@test_throws ErrorException setter(probs[1], 4, buf)

test/downstream/batchedinterface_arrayvars.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,36 @@ buf ./= 10
4040

4141
setter!(probs[1], 1, buf)
4242
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
43+
44+
@variables a b[1:2] c
45+
46+
syss = [
47+
SymbolCache([x..., y], [a, b...]),
48+
SymbolCache([x[1], y, z], [a, b..., c])
49+
]
50+
syms = [
51+
[x, y, a, b...],
52+
[x[1], y, b[2], c]
53+
]
54+
probs = [
55+
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
56+
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.1, 0.4, 0.5, 0.6])
57+
]
58+
59+
bi = BatchedInterface(zip(syss, syms)...)
60+
61+
buf = getu(bi)(probs...)
62+
buf .*= 100
63+
setter = setsym_oop(bi)
64+
vals = setter(probs..., buf)
65+
@test length(vals) == length(probs)
66+
@test vals[1][1] == [100.0, 200.0, 300.0]
67+
@test vals[1][2] == [10.0, 20.0, 30.0]
68+
@test vals[2][1] == [100.0, 300.0, 6.0]
69+
@test vals[2][2] == [0.1, 0.4, 30.0, 60.0]
70+
71+
buf ./= 10
72+
vals = setter(probs[1], 1, buf)
73+
@test length(vals) == 2
74+
@test vals[1] == [10.0, 20.0, 30.0]
75+
@test vals[2] == [1.0, 2.0, 3.0]

0 commit comments

Comments
 (0)