Skip to content

Commit ac86106

Browse files
feat: add BatchedInterface
1 parent f43b850 commit ac86106

File tree

5 files changed

+278
-0
lines changed

5 files changed

+278
-0
lines changed

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,10 @@ symbolic_evaluate
9090
SymbolCache
9191
ProblemState
9292
```
93+
94+
### Batched Queries and Updates
95+
96+
```@docs
97+
BatchedInterface
98+
associated_systems
99+
```

src/SymbolicIndexingInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ include("parameter_indexing.jl")
3131
export state_values, set_state!, current_time, getu, setu
3232
include("state_indexing.jl")
3333

34+
export BatchedInterface, associated_systems
35+
include("batched_interface.jl")
36+
3437
export ProblemState
3538
include("problem_state.jl")
3639

src/batched_interface.jl

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
"""
2+
struct BatchedInterface{S <: AbstractVector, I}
3+
function BatchedInterface(syssyms::Tuple...)
4+
5+
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
6+
Given `Tuple`s, where the first element of each tuple is a system and the second an
7+
array of symbols (either variables or parameters) in the system, `BatchedInterface` will
8+
compute the union of all symbols and associate each symbol with the first system with
9+
which it occurs.
10+
11+
For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and
12+
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will
13+
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had
14+
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also
15+
be retained internally.
16+
17+
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
18+
[`variable_index`](@ref) to query the order of symbols in the union.
19+
20+
See [`getu`](@ref) and [`setu`](@ref) for further details.
21+
22+
See also: [`associated_systems`](@ref).
23+
"""
24+
struct BatchedInterface{S <: AbstractVector, I, T}
25+
"Order of symbols in the union."
26+
symbol_order::S
27+
"Index of the system each symbol in the union is associated with."
28+
associated_systems::Vector{Int}
29+
"Index of symbol in the system it is associated with."
30+
associated_indexes::I
31+
"Whether the symbol is a state in the system it is associated with."
32+
isstate::BitVector
33+
"Map from system to indexes of its symbols in the union."
34+
system_to_symbol_subset::Vector{Vector{Int}}
35+
"Map from system to indexes of its symbols in the system."
36+
system_to_symbol_indexes::Vector{Vector{T}}
37+
"Map from system to whether each of its symbols is a state in the system."
38+
system_to_isstate::Vector{BitVector}
39+
end
40+
41+
function BatchedInterface(syssyms::Tuple...)
42+
symbol_order = []
43+
associated_systems = Int[]
44+
associated_indexes = []
45+
isstate = BitVector()
46+
system_to_symbol_subset = Vector{Int}[]
47+
system_to_symbol_indexes = []
48+
system_to_isstate = BitVector[]
49+
for (i, (sys, syms)) in enumerate(syssyms)
50+
symbol_subset = Int[]
51+
symbol_indexes = []
52+
system_isstate = BitVector()
53+
for sym in syms
54+
if symbolic_type(sym) === NotSymbolic()
55+
error("Only symbolic variables allowed in BatchedInterface.")
56+
end
57+
if !is_variable(sys, sym) && !is_parameter(sys, sym)
58+
error("Only variables and parameters allowed in BatchedInterface.")
59+
end
60+
if !any(isequal(sym), symbol_order)
61+
push!(symbol_order, sym)
62+
push!(associated_systems, i)
63+
push!(isstate, is_variable(sys, sym))
64+
if isstate[end]
65+
push!(associated_indexes, variable_index(sys, sym))
66+
else
67+
push!(associated_indexes, parameter_index(sys, sym))
68+
end
69+
end
70+
push!(symbol_subset, findfirst(isequal(sym), symbol_order))
71+
push!(system_isstate, is_variable(sys, sym))
72+
push!(symbol_indexes,
73+
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym))
74+
end
75+
push!(system_to_symbol_subset, symbol_subset)
76+
push!(system_to_symbol_indexes, identity.(symbol_indexes))
77+
push!(system_to_isstate, system_isstate)
78+
end
79+
symbol_order = identity.(symbol_order)
80+
associated_indexes = identity.(associated_indexes)
81+
system_to_symbol_indexes = identity.(system_to_symbol_indexes)
82+
83+
return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
84+
eltype(eltype(system_to_symbol_indexes))}(
85+
symbol_order, associated_systems, associated_indexes, isstate,
86+
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
87+
end
88+
89+
variable_symbols(bi::BatchedInterface) = bi.symbol_order
90+
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order)
91+
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing
92+
93+
"""
94+
associated_systems(bi::BatchedInterface)
95+
96+
Return an array of integers of the same length as `variable_symbols(bi)` where each value
97+
is the index of the system associated with the corresponding symbol in
98+
`variable_symbols(bi)`.
99+
"""
100+
associated_systems(bi::BatchedInterface) = bi.associated_systems
101+
102+
"""
103+
getu(bi::BatchedInterface)
104+
105+
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
106+
return a function which takes `n` corresponding problems and returns an array of the values
107+
of the symbols in the union. The returned function can also be passed an `AbstractArray` of
108+
the appropriate `eltype` and size as its first argument, in which case the operation will
109+
populate the array in-place with the values of the symbols in the union.
110+
111+
Note that all of the problems passed to the function returned by `getu` must satisfy
112+
`is_timeseries(prob) === NotTimeseries()`.
113+
114+
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
115+
obtained from the problem corresponding to the associated system (i.e. the problem at
116+
index `associated_systems(bi)[i]`).
117+
118+
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
119+
[`NotTimeseries`](@ref).
120+
"""
121+
function getu(bi::BatchedInterface)
122+
numprobs = length(bi.system_to_symbol_subset)
123+
probnames = [Symbol(:prob, i) for i in 1:numprobs]
124+
125+
fnbody = quote end
126+
for (i, (prob, idx, isstate)) in enumerate(zip(
127+
bi.associated_systems, bi.associated_indexes, bi.isstate))
128+
symname = Symbol(:sym, i)
129+
getter = isstate ? state_values : parameter_values
130+
probname = probnames[prob]
131+
push!(fnbody.args, :($symname = $getter($probname, $idx)))
132+
end
133+
134+
oop_expr = Expr(:vect)
135+
for i in 1:length(bi.symbol_order)
136+
push!(oop_expr.args, Symbol(:sym, i))
137+
end
138+
139+
iip_expr = quote end
140+
for i in 1:length(bi.symbol_order)
141+
symname = Symbol(:sym, i)
142+
push!(iip_expr.args, :(out[$i] = $symname))
143+
end
144+
145+
oopfn = Expr(
146+
:function,
147+
Expr(:tuple, probnames...),
148+
quote
149+
$fnbody
150+
$oop_expr
151+
end
152+
)
153+
iipfn = Expr(
154+
:function,
155+
Expr(:tuple, :out, probnames...),
156+
quote
157+
$fnbody
158+
$iip_expr
159+
out
160+
end
161+
)
162+
163+
return let oop = @RuntimeGeneratedFunction(oopfn),
164+
iip = @RuntimeGeneratedFunction(iipfn)
165+
166+
_getter(probs...) = oop(probs...)
167+
_getter(out::AbstractArray, probs...) = iip(out, probs...)
168+
_getter
169+
end
170+
end
171+
172+
"""
173+
setu(bi::BatchedInterface)
174+
175+
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
176+
return a function which takes `n` corresponding problems and an array of the values, and
177+
updates each of the problems with the values of the corresponding symbols.
178+
179+
Note that all of the problems passed to the function returned by `setu` must satisfy
180+
`is_timeseries(prob) === NotTimeseries()`.
181+
182+
Note that if any subset of the `n` systems share common symbols (among those passed to
183+
`BatchedInterface`) then all of the corresponding problems in the subset will be updated
184+
with the values of the common symbols.
185+
186+
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
187+
"""
188+
function setu(bi::BatchedInterface)
189+
numprobs = length(bi.system_to_symbol_subset)
190+
probnames = [Symbol(:prob, i) for i in 1:numprobs]
191+
192+
fnbody = quote end
193+
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
194+
probname = probnames[sys_idx]
195+
for (idx_in_subset, idx_in_union) in enumerate(subset)
196+
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
197+
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
198+
setter = isstate ? set_state! : set_parameter!
199+
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
200+
end
201+
# also run hook
202+
if !all(bi.system_to_isstate[sys_idx])
203+
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
204+
for idx_in_subset in 1:length(subset)
205+
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
206+
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
207+
end
208+
end
209+
push!(fnbody.args, :(return vals))
210+
fnexpr = Expr(
211+
:function,
212+
Expr(:tuple, probnames..., :vals),
213+
fnbody
214+
)
215+
return @RuntimeGeneratedFunction(fnexpr)
216+
end

test/batched_interface_test.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using SymbolicIndexingInterface
2+
3+
syss = [
4+
SymbolCache([:x, :y, :z], [:a, :b, :c], :t),
5+
SymbolCache([:z, :w, :v], [:c, :e, :f]),
6+
SymbolCache([:w, :x, :u], [:e, :a, :f])
7+
]
8+
syms = [
9+
[:x, :z, :b, :c],
10+
[:z, :w, :c, :f],
11+
[:w, :x, :e, :a]
12+
]
13+
probs = [
14+
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
15+
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]),
16+
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9])
17+
]
18+
19+
@test_throws ErrorException BatchedInterface((syss[1], [:x, 3]))
20+
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)]))
21+
@test_throws ErrorException BatchedInterface((syss[1], [:t]))
22+
23+
bi = BatchedInterface(zip(syss, syms)...)
24+
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a]
25+
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
26+
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing]
27+
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
28+
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0]
29+
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3]
30+
31+
getter = getu(bi)
32+
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
33+
buf = zeros(8)
34+
@inferred getter(buf, probs...)
35+
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
36+
37+
setter! = setu(bi)
38+
buf .*= 100
39+
setter!(probs..., buf)
40+
41+
@test state_values(probs[1]) == [100.0, 2.0, 300.0]
42+
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1])
43+
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0]
44+
@test state_values(probs[2]) == [300.0, 500.0, 6.0]
45+
# Similarly for :e
46+
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0]
47+
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
48+
# Similarly for :f
49+
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@ end
2929
@safetestset "ProblemState test" begin
3030
@time include("problem_state_test.jl")
3131
end
32+
@safetestset "BatchedInterface test" begin
33+
@time include("batched_interface_test.jl")
34+
end

0 commit comments

Comments
 (0)