Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/problem_state.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
"""
struct ProblemState
function ProblemState(; u = nothing, p = nothing, t = nothing)
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)

A value provider struct which can be used as an argument to the function returned by
[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and
current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
objects.

A history function may be provided using the `h` keyword, which will be returned with
[`get_history_function`](@ref).
"""
struct ProblemState{U, P, T}
struct ProblemState{U, P, T, H}
u::U
p::P
t::T
h::H
end

ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)
ProblemState(u, p, t, h)
end

state_values(ps::ProblemState) = ps.u
parameter_values(ps::ProblemState) = ps.p
current_time(ps::ProblemState) = ps.t
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)
get_history_function(ps::ProblemState) = ps.h
5 changes: 4 additions & 1 deletion test/problem_state_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using SymbolicIndexingInterface
using Test

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)
prob = ProblemState(;
u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5, h = Returns(ones(3)))

for (i, sym) in enumerate(variable_symbols(sys))
@test getsym(sys, sym)(prob) == prob.u[i]
Expand All @@ -13,3 +14,5 @@ end
@test getsym(sys, :t)(prob) == prob.t

@test getsym(sys, :(x + a + t))(prob) == 1.6

@test get_history_function(prob) !== nothing
3 changes: 3 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ getter = getsym(sys, :(x + y))
@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) ≈ 2.8

pstate = ProblemState(; u = u0, p = p, t = ts[1], h = t -> t .* ones(length(u0)))
@test getter(pstate) ≈ 2.8

struct TupleObservedWrapper{S}
sys::S
end
Expand Down
Loading