Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ with_updated_parameter_timeseries_values
```@docs
BatchedInterface
associated_systems
setsym_oop
```

## Container objects
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ include("parameter_indexing.jl")
export getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
export BatchedInterface, setsym_oop, associated_systems
include("batched_interface.jl")

export ProblemState
Expand Down
115 changes: 112 additions & 3 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ See [`getu`](@ref) and [`setu`](@ref) for further details.

See also: [`associated_systems`](@ref).
"""
struct BatchedInterface{S <: AbstractVector, I, T}
struct BatchedInterface{S <: AbstractVector, I, T, P}
"Order of symbols in the union."
symbol_order::S
"Index of the index provider each symbol in the union is associated with."
Expand All @@ -36,6 +36,8 @@ struct BatchedInterface{S <: AbstractVector, I, T}
system_to_symbol_indexes::Vector{Vector{T}}
"Map from index provider to whether each of its symbols is a state in the index provider."
system_to_isstate::Vector{BitVector}
"Index providers, in order"
index_providers::Vector{P}
end

function BatchedInterface(syssyms::Tuple...)
Expand All @@ -46,11 +48,18 @@ function BatchedInterface(syssyms::Tuple...)
system_to_symbol_subset = Vector{Int}[]
system_to_symbol_indexes = []
system_to_isstate = BitVector[]
index_providers = []
for (i, (sys, syms)) in enumerate(syssyms)
symbol_subset = Int[]
symbol_indexes = []
system_isstate = BitVector()
allsyms = []
root_indp = sys
while applicable(symbolic_container, root_indp) &&
(sc = symbolic_container(root_indp)) != root_indp
root_indp = sc
end
push!(index_providers, root_indp)
for sym in syms
if symbolic_type(sym) === NotSymbolic()
error("Only symbolic variables allowed in BatchedInterface.")
Expand Down Expand Up @@ -89,9 +98,10 @@ function BatchedInterface(syssyms::Tuple...)
system_to_symbol_indexes = identity.(system_to_symbol_indexes)

return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
eltype(eltype(system_to_symbol_indexes))}(
eltype(eltype(system_to_symbol_indexes)), eltype(index_providers)}(
symbol_order, associated_systems, associated_indexes, isstate,
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate,
identity.(index_providers))
end

variable_symbols(bi::BatchedInterface) = bi.symbol_order
Expand Down Expand Up @@ -268,3 +278,102 @@ function setu(bi::BatchedInterface)
setter!
end
end

"""
setsym_oop(bi::BatchedInterface)

Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding
symbols), return a function which takes `n` corresponding value providers and an array of
values, and returns an `n`-tuple where each element is a 2-tuple consisting of the updated
state values and parameter values of the corresponding value provider. Requires that the
value provider implement [`state_values`](@ref), [`parameter_values`](@ref). The updates are
performed out-of-place using [`remake_buffer`](@ref).

Note that all of the value providers passed to the returned function must satisfy
`is_timeseries(prob) === NotTimeseries()`.

Note that if any subset of the `n` index providers share common symbols (among those passed
to `BatchedInterface`) then all of the corresponding value providers in the subset will be
updated with the values of the common symbols.

See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setsym_oop(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]
arg = :vals
full_update = Expr(:block)

function get_update_expr(prob::Symbol, sys_i::Int)
union_idxs = bi.system_to_symbol_subset[sys_i]
indp_idxs = bi.system_to_symbol_indexes[sys_i]
isstate = bi.system_to_isstate[sys_i]
indp = bi.index_providers[sys_i]
curexpr = Expr(:block)

statessym = Symbol(:states_, sys_i)
if all(.!isstate)
push!(curexpr.args, :($statessym = $state_values($prob)))
else
state_idxssym = Symbol(:state_idxs_, sys_i)
state_idxs = indp_idxs[isstate]
state_valssym = Symbol(:state_vals_, sys_i)
vals_idxs = union_idxs[isstate]
push!(curexpr.args, :($state_idxssym = $state_idxs))
push!(curexpr.args, :($state_valssym = $view($arg, $vals_idxs)))
push!(curexpr.args,
:($statessym = $remake_buffer(
syss[$sys_i], $state_values($prob), $state_idxssym, $state_valssym)))
end

paramssym = Symbol(:params_, sys_i)
if all(isstate)
push!(curexpr.args, :($paramssym = $parameter_values($prob)))
else
param_idxssym = Symbol(:param_idxs_, sys_i)
param_idxs = indp_idxs[.!isstate]
param_valssym = Symbol(:param_vals, sys_i)
vals_idxs = union_idxs[.!isstate]
push!(curexpr.args, :($param_idxssym = $param_idxs))
push!(curexpr.args, :($param_valssym = $view($arg, $vals_idxs)))
push!(curexpr.args,
:($paramssym = $remake_buffer(
syss[$sys_i], $parameter_values($prob), $param_idxssym, $param_valssym)))
end

return curexpr, statessym, paramssym
end

full_update_expr = Expr(:block)
full_update_retval = Expr(:tuple)
partial_update_expr = Expr(:block)
cur_partial_update_expr = partial_update_expr
for i in 1:numprobs
update_expr, statesym, paramsym = get_update_expr(probnames[i], i)
push!(full_update_expr.args, update_expr)
push!(full_update_retval.args, Expr(:tuple, statesym, paramsym))

cur_ifexpr = Expr(i == 1 ? :if : :elseif, :(idx == $i))
update_expr, statesym, paramsym = get_update_expr(:prob, i)
push!(update_expr.args, :(return ($statesym, $paramsym)))
push!(cur_ifexpr.args, update_expr)
push!(cur_partial_update_expr.args, cur_ifexpr)
cur_partial_update_expr = cur_ifexpr
end
push!(full_update_expr.args, :(return $full_update_retval))
push!(cur_partial_update_expr.args, :(error("Invalid problem index $idx")))

full_update_fnexpr = Expr(
:function, Expr(:tuple, :syss, probnames..., arg), full_update_expr)
partial_update_fnexpr = Expr(
:function, Expr(:tuple, :syss, :prob, :idx, arg), partial_update_expr)

return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr),
syss = Tuple(bi.index_providers)

setter(args...) = full_update(syss, args...)
setter(prob, idx::Int, vals::AbstractVector) = partial_update(syss, prob, idx, vals)
setter
end
end
28 changes: 28 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,31 @@ setter!(probs[1], 1, buf)
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)

setter!(probs..., buf)

setter = setsym_oop(bi)

buf .*= 100
vals = setter(probs..., buf)
@test length(vals) == length(probs)
@test vals[1][1] == [100.0, 2.0, 300.0]
@test vals[1][2] == [0.1, 20.0, 30.0]
@test vals[2][1] == [300.0, 500.0, 6.0]
@test vals[2][2] == [30.0, 0.5, 60.0]
@test vals[3][1] == [500.0, 100.0, 9.0]
@test vals[3][2] == [70.0, 80.0, 0.9]

# out-of-place
for i in 1:3
@test vals[i][1] != state_values(probs[i])
@test vals[i][2] != parameter_values(probs[i])
end

buf ./= 10
vals = setter(probs[1], 1, buf)
@test length(vals) == 2
@test vals[1] == [10.0, 2.0, 30.0]
@test vals[2] == [0.1, 2.0, 3.0]

@test_throws ErrorException setter(probs[1], 4, buf)
33 changes: 33 additions & 0 deletions test/downstream/batchedinterface_arrayvars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,36 @@ buf ./= 10

setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]

@variables a b[1:2] c

syss = [
SymbolCache([x..., y], [a, b...]),
SymbolCache([x[1], y, z], [a, b..., c])
]
syms = [
[x, y, a, b...],
[x[1], y, b[2], c]
]
probs = [
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.1, 0.4, 0.5, 0.6])
]

bi = BatchedInterface(zip(syss, syms)...)

buf = getu(bi)(probs...)
buf .*= 100
setter = setsym_oop(bi)
vals = setter(probs..., buf)
@test length(vals) == length(probs)
@test vals[1][1] == [100.0, 200.0, 300.0]
@test vals[1][2] == [10.0, 20.0, 30.0]
@test vals[2][1] == [100.0, 300.0, 6.0]
@test vals[2][2] == [0.1, 0.4, 30.0, 60.0]

buf ./= 10
vals = setter(probs[1], 1, buf)
@test length(vals) == 2
@test vals[1] == [10.0, 20.0, 30.0]
@test vals[2] == [1.0, 2.0, 3.0]
Loading