Skip to content

Commit 535bd8d

Browse files
Merge pull request #109 from SciML/as/prob-state-history
feat: add optional history function to `ProblemState`
2 parents 2a247d3 + e440ff7 commit 535bd8d

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

src/problem_state.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
11
"""
22
struct ProblemState
3-
function ProblemState(; u = nothing, p = nothing, t = nothing)
3+
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)
44
55
A value provider struct which can be used as an argument to the function returned by
66
[`getsym`](@ref) or [`setsym`](@ref). It stores the state vector, parameter object and
77
current time, and forwards calls to [`state_values`](@ref), [`parameter_values`](@ref),
88
[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained
99
objects.
10+
11+
A history function may be provided using the `h` keyword, which will be returned with
12+
[`get_history_function`](@ref).
1013
"""
11-
struct ProblemState{U, P, T}
14+
struct ProblemState{U, P, T, H}
1215
u::U
1316
p::P
1417
t::T
18+
h::H
1519
end
1620

17-
ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t)
21+
function ProblemState(; u = nothing, p = nothing, t = nothing, h = nothing)
22+
ProblemState(u, p, t, h)
23+
end
1824

1925
state_values(ps::ProblemState) = ps.u
2026
parameter_values(ps::ProblemState) = ps.p
2127
current_time(ps::ProblemState) = ps.t
2228
set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx)
2329
set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx)
30+
get_history_function(ps::ProblemState) = ps.h

test/problem_state_test.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using SymbolicIndexingInterface
22
using Test
33

44
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
5-
prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5)
5+
prob = ProblemState(;
6+
u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5, h = Returns(ones(3)))
67

78
for (i, sym) in enumerate(variable_symbols(sys))
89
@test getsym(sys, sym)(prob) == prob.u[i]
@@ -13,3 +14,5 @@ end
1314
@test getsym(sys, :t)(prob) == prob.t
1415

1516
@test getsym(sys, :(x + a + t))(prob) == 1.6
17+
18+
@test get_history_function(prob) !== nothing

test/state_indexing_test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ getter = getsym(sys, :(x + y))
347347
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
348348
@test getter(fs, 1) 2.8
349349

350+
pstate = ProblemState(; u = u0, p = p, t = ts[1], h = t -> t .* ones(length(u0)))
351+
@test getter(pstate) 2.8
352+
350353
struct TupleObservedWrapper{S}
351354
sys::S
352355
end

0 commit comments

Comments
 (0)