diff --git a/docs/src/api.md b/docs/src/api.md index 21ba613..cf767b9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -112,6 +112,7 @@ with_updated_parameter_timeseries_values ```@docs BatchedInterface associated_systems +setsym_oop ``` ## Container objects diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index aa89975..7b2f3cc 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -38,7 +38,7 @@ include("parameter_indexing.jl") export getu, setu include("state_indexing.jl") -export BatchedInterface, associated_systems +export BatchedInterface, setsym_oop, associated_systems include("batched_interface.jl") export ProblemState diff --git a/src/batched_interface.jl b/src/batched_interface.jl index 25a06b0..c8d02d8 100644 --- a/src/batched_interface.jl +++ b/src/batched_interface.jl @@ -21,7 +21,7 @@ See [`getu`](@ref) and [`setu`](@ref) for further details. See also: [`associated_systems`](@ref). """ -struct BatchedInterface{S <: AbstractVector, I, T} +struct BatchedInterface{S <: AbstractVector, I, T, P} "Order of symbols in the union." symbol_order::S "Index of the index provider each symbol in the union is associated with." @@ -36,6 +36,8 @@ struct BatchedInterface{S <: AbstractVector, I, T} system_to_symbol_indexes::Vector{Vector{T}} "Map from index provider to whether each of its symbols is a state in the index provider." system_to_isstate::Vector{BitVector} + "Index providers, in order" + index_providers::Vector{P} end function BatchedInterface(syssyms::Tuple...) @@ -46,11 +48,18 @@ function BatchedInterface(syssyms::Tuple...) system_to_symbol_subset = Vector{Int}[] system_to_symbol_indexes = [] system_to_isstate = BitVector[] + index_providers = [] for (i, (sys, syms)) in enumerate(syssyms) symbol_subset = Int[] symbol_indexes = [] system_isstate = BitVector() allsyms = [] + root_indp = sys + while applicable(symbolic_container, root_indp) && + (sc = symbolic_container(root_indp)) != root_indp + root_indp = sc + end + push!(index_providers, root_indp) for sym in syms if symbolic_type(sym) === NotSymbolic() error("Only symbolic variables allowed in BatchedInterface.") @@ -89,9 +98,10 @@ function BatchedInterface(syssyms::Tuple...) system_to_symbol_indexes = identity.(system_to_symbol_indexes) return BatchedInterface{typeof(symbol_order), typeof(associated_indexes), - eltype(eltype(system_to_symbol_indexes))}( + eltype(eltype(system_to_symbol_indexes)), eltype(index_providers)}( symbol_order, associated_systems, associated_indexes, isstate, - system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate) + system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate, + identity.(index_providers)) end variable_symbols(bi::BatchedInterface) = bi.symbol_order @@ -268,3 +278,102 @@ function setu(bi::BatchedInterface) setter! end end + +""" + setsym_oop(bi::BatchedInterface) + +Given a [`BatchedInterface`](@ref) composed from `n` index providers (and corresponding +symbols), return a function which takes `n` corresponding value providers and an array of +values, and returns an `n`-tuple where each element is a 2-tuple consisting of the updated +state values and parameter values of the corresponding value provider. Requires that the +value provider implement [`state_values`](@ref), [`parameter_values`](@ref). The updates are +performed out-of-place using [`remake_buffer`](@ref). + +Note that all of the value providers passed to the returned function must satisfy +`is_timeseries(prob) === NotTimeseries()`. + +Note that if any subset of the `n` index providers share common symbols (among those passed +to `BatchedInterface`) then all of the corresponding value providers in the subset will be +updated with the values of the common symbols. + +See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). +""" +function setsym_oop(bi::BatchedInterface) + numprobs = length(bi.system_to_symbol_subset) + probnames = [Symbol(:prob, i) for i in 1:numprobs] + arg = :vals + full_update = Expr(:block) + + function get_update_expr(prob::Symbol, sys_i::Int) + union_idxs = bi.system_to_symbol_subset[sys_i] + indp_idxs = bi.system_to_symbol_indexes[sys_i] + isstate = bi.system_to_isstate[sys_i] + indp = bi.index_providers[sys_i] + curexpr = Expr(:block) + + statessym = Symbol(:states_, sys_i) + if all(.!isstate) + push!(curexpr.args, :($statessym = $state_values($prob))) + else + state_idxssym = Symbol(:state_idxs_, sys_i) + state_idxs = indp_idxs[isstate] + state_valssym = Symbol(:state_vals_, sys_i) + vals_idxs = union_idxs[isstate] + push!(curexpr.args, :($state_idxssym = $state_idxs)) + push!(curexpr.args, :($state_valssym = $view($arg, $vals_idxs))) + push!(curexpr.args, + :($statessym = $remake_buffer( + syss[$sys_i], $state_values($prob), $state_idxssym, $state_valssym))) + end + + paramssym = Symbol(:params_, sys_i) + if all(isstate) + push!(curexpr.args, :($paramssym = $parameter_values($prob))) + else + param_idxssym = Symbol(:param_idxs_, sys_i) + param_idxs = indp_idxs[.!isstate] + param_valssym = Symbol(:param_vals, sys_i) + vals_idxs = union_idxs[.!isstate] + push!(curexpr.args, :($param_idxssym = $param_idxs)) + push!(curexpr.args, :($param_valssym = $view($arg, $vals_idxs))) + push!(curexpr.args, + :($paramssym = $remake_buffer( + syss[$sys_i], $parameter_values($prob), $param_idxssym, $param_valssym))) + end + + return curexpr, statessym, paramssym + end + + full_update_expr = Expr(:block) + full_update_retval = Expr(:tuple) + partial_update_expr = Expr(:block) + cur_partial_update_expr = partial_update_expr + for i in 1:numprobs + update_expr, statesym, paramsym = get_update_expr(probnames[i], i) + push!(full_update_expr.args, update_expr) + push!(full_update_retval.args, Expr(:tuple, statesym, paramsym)) + + cur_ifexpr = Expr(i == 1 ? :if : :elseif, :(idx == $i)) + update_expr, statesym, paramsym = get_update_expr(:prob, i) + push!(update_expr.args, :(return ($statesym, $paramsym))) + push!(cur_ifexpr.args, update_expr) + push!(cur_partial_update_expr.args, cur_ifexpr) + cur_partial_update_expr = cur_ifexpr + end + push!(full_update_expr.args, :(return $full_update_retval)) + push!(cur_partial_update_expr.args, :(error("Invalid problem index $idx"))) + + full_update_fnexpr = Expr( + :function, Expr(:tuple, :syss, probnames..., arg), full_update_expr) + partial_update_fnexpr = Expr( + :function, Expr(:tuple, :syss, :prob, :idx, arg), partial_update_expr) + + return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr), + partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr), + syss = Tuple(bi.index_providers) + + setter(args...) = full_update(syss, args...) + setter(prob, idx::Int, vals::AbstractVector) = partial_update(syss, prob, idx, vals) + setter + end +end diff --git a/test/batched_interface_test.jl b/test/batched_interface_test.jl index 3e622cd..13dd415 100644 --- a/test/batched_interface_test.jl +++ b/test/batched_interface_test.jl @@ -54,3 +54,31 @@ setter!(probs[1], 1, buf) @test parameter_values(probs[1]) == [0.1, 0.2, 0.3] @test_throws ErrorException setter!(probs[1], 4, buf) + +setter!(probs..., buf) + +setter = setsym_oop(bi) + +buf .*= 100 +vals = setter(probs..., buf) +@test length(vals) == length(probs) +@test vals[1][1] == [100.0, 2.0, 300.0] +@test vals[1][2] == [0.1, 20.0, 30.0] +@test vals[2][1] == [300.0, 500.0, 6.0] +@test vals[2][2] == [30.0, 0.5, 60.0] +@test vals[3][1] == [500.0, 100.0, 9.0] +@test vals[3][2] == [70.0, 80.0, 0.9] + +# out-of-place +for i in 1:3 + @test vals[i][1] != state_values(probs[i]) + @test vals[i][2] != parameter_values(probs[i]) +end + +buf ./= 10 +vals = setter(probs[1], 1, buf) +@test length(vals) == 2 +@test vals[1] == [10.0, 2.0, 30.0] +@test vals[2] == [0.1, 2.0, 3.0] + +@test_throws ErrorException setter(probs[1], 4, buf) diff --git a/test/downstream/batchedinterface_arrayvars.jl b/test/downstream/batchedinterface_arrayvars.jl index 83c003d..eaba28c 100644 --- a/test/downstream/batchedinterface_arrayvars.jl +++ b/test/downstream/batchedinterface_arrayvars.jl @@ -40,3 +40,36 @@ buf ./= 10 setter!(probs[1], 1, buf) @test state_values(probs[1]) == [1.0, 2.0, 3.0] + +@variables a b[1:2] c + +syss = [ + SymbolCache([x..., y], [a, b...]), + SymbolCache([x[1], y, z], [a, b..., c]) +] +syms = [ + [x, y, a, b...], + [x[1], y, b[2], c] +] +probs = [ + ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]), + ProblemState(; u = [4.0, 5.0, 6.0], p = [0.1, 0.4, 0.5, 0.6]) +] + +bi = BatchedInterface(zip(syss, syms)...) + +buf = getu(bi)(probs...) +buf .*= 100 +setter = setsym_oop(bi) +vals = setter(probs..., buf) +@test length(vals) == length(probs) +@test vals[1][1] == [100.0, 200.0, 300.0] +@test vals[1][2] == [10.0, 20.0, 30.0] +@test vals[2][1] == [100.0, 300.0, 6.0] +@test vals[2][2] == [0.1, 0.4, 30.0, 60.0] + +buf ./= 10 +vals = setter(probs[1], 1, buf) +@test length(vals) == 2 +@test vals[1] == [10.0, 20.0, 30.0] +@test vals[2] == [1.0, 2.0, 3.0]