Skip to content

Commit 2107adc

Browse files
feat: add all_solvable_symbols and all_symbols
1 parent 0e67d45 commit 2107adc

File tree

7 files changed

+58
-3
lines changed

7 files changed

+58
-3
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ is_observed
1414
observed
1515
is_time_dependent
1616
constant_structure
17+
all_solvable_symbols
18+
all_symbols
1719
parameter_values
1820
getp
1921
setp

docs/src/tutorial.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ struct ExampleSolution
4141
state_index::Dict{Symbol,Int}
4242
parameter_index::Dict{Symbol,Int}
4343
independent_variable::Union{Symbol,Nothing}
44+
# mapping from observed variable to Expr to calculate its value
45+
observed::Dict{Symbol,Expr}
4446
u::Vector{Vector{Float64}}
4547
p::Vector{Float64}
4648
t::Vector{Float64}
@@ -86,9 +88,9 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSolu
8688
sys.independent_variable === nothing ? [] : [sys.independent_variable]
8789
end
8890

89-
# this types accepts `Expr` for observed expressions involving state/parameter
91+
# this type accepts `Expr` for observed expressions involving state/parameter/observed
9092
# variables
91-
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr
93+
SymbolicIndexingInterface.is_observed(sys::ExampleSolution, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)
9294

9395
function SymbolicIndexingInterface.observed(sys::ExampleSolution, sym::Expr)
9496
if is_time_dependent(sys)
@@ -109,6 +111,21 @@ function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSolution)
109111
end
110112

111113
SymbolicIndexingInterface.constant_structure(::ExampleSolution) = true
114+
115+
function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSolution)
116+
return vcat(
117+
collect(keys(sys.state_index)),
118+
collect(keys(sys.observed)),
119+
)
120+
end
121+
122+
function SymbolicIndexingInterface.all_symbols(sys::ExampleSolution)
123+
return vcat(
124+
all_solvable_symbols(sys),
125+
collect(keys(sys.parameter_index)),
126+
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
127+
)
128+
end
112129
```
113130

114131
Note that the method definitions are all assuming `constant_structure(p) == true`.

src/SymbolicIndexingInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ include("trait.jl")
55

66
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
77
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
8-
observed, is_time_dependent, constant_structure, symbolic_container
8+
observed, is_time_dependent, constant_structure, symbolic_container, all_solvable_symbols,
9+
all_symbols
910
include("interface.jl")
1011

1112
export SymbolCache

src/interface.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,19 @@ Check if `sys` has a constant structure. Constant structure systems do not chang
110110
number of variables or parameters over time.
111111
"""
112112
constant_structure(sys) = constant_structure(symbolic_container(sys))
113+
114+
"""
115+
all_solvable_symbols(sys)
116+
117+
Return an array of all symbols in the system that can be solved for. This includes
118+
observed variables, but not parameters or independent variables.
119+
"""
120+
all_solvable_symbols(sys) = all_solvable_symbols(symbolic_container(sys))
121+
122+
"""
123+
all_symbols(sys)
124+
125+
Return an array of all symbols in the system. This includes parameters and independent
126+
variables.
127+
"""
128+
all_symbols(sys) = all_symbols(symbolic_container(sys))

src/symbol_cache.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ function is_time_dependent(sc::SymbolCache)
7171
end
7272
end
7373
constant_structure(::SymbolCache) = true
74+
all_solvable_symbols(sc::SymbolCache) = variable_symbols(sc)
75+
all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
7476

7577
function Base.copy(sc::SymbolCache)
7678
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),

test/example_test.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ function SymbolicIndexingInterface.observed(sys::SystemMockup, sym, states = not
5353
end
5454
SymbolicIndexingInterface.is_time_dependent(sys::SystemMockup) = isequal(sys.indepvar, :t)
5555
SymbolicIndexingInterface.constant_structure(sys::SystemMockup) = sys.static
56+
SymbolicIndexingInterface.all_solvable_symbols(sys::SystemMockup) = sys.vars
57+
function SymbolicIndexingInterface.all_symbols(sys::SystemMockup)
58+
vcat(sys.vars, sys.params, independent_variable_symbols(sys))
59+
end
5660

5761
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
5862

@@ -79,6 +83,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
7983
@test variable_symbols(sys) == [:x, :y, :z]
8084
@test parameter_symbols(sys) == [:a, :b, :c]
8185
@test independent_variable_symbols(sys) == [:t]
86+
@test all_solvable_symbols(sys) == [:x, :y, :z]
87+
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
8288

8389
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
8490

@@ -93,6 +99,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
9399
@test variable_symbols(sys) == [:x, :y, :z]
94100
@test parameter_symbols(sys) == [:a, :b, :c]
95101
@test independent_variable_symbols(sys) == []
102+
@test all_solvable_symbols(sys) == [:x, :y, :z]
103+
@test sort(all_symbols(sys)) == [:a, :b, :c, :x, :y, :z]
96104

97105
sys = SystemMockup(false, [:x, :y, :z], [:a, :b, :c], :t)
98106
@test !constant_structure(sys)
@@ -115,3 +123,5 @@ end
115123
@test variable_symbols(sys, 3) == [:x, :y, :z]
116124
@test parameter_symbols(sys) == [:a, :b, :c]
117125
@test independent_variable_symbols(sys) == [:t]
126+
@test all_solvable_symbols(sys) == [:x, :y, :z]
127+
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]

test/symbol_cache_test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
1414
@test variable_symbols(sc) == [:x, :y, :z]
1515
@test parameter_symbols(sc) == [:a, :b]
1616
@test independent_variable_symbols(sc) == [:t]
17+
@test all_solvable_symbols(sc) == [:x, :y, :z]
18+
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]
1719

1820
sc = SymbolCache([:x, :y], [:a, :b])
1921
@test !is_time_dependent(sc)
22+
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
2023
# make sure the constructor works
2124
@test_nowarn SymbolCache([:x, :y])
2225

@@ -30,12 +33,16 @@ sc = SymbolCache()
3033
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
3134
@test independent_variable_symbols(sc) == []
3235
@test !is_time_dependent(sc)
36+
@test all_solvable_symbols(sc) == []
37+
@test all_symbols(sc) == []
3338

3439
sc = SymbolCache(nothing, nothing, :t)
3540
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
3641
@test is_independent_variable(sc, :t)
3742
@test independent_variable_symbols(sc) == [:t]
3843
@test is_time_dependent(sc)
44+
@test all_solvable_symbols(sc) == []
45+
@test all_symbols(sc) == [:t]
3946

4047
sc2 = copy(sc)
4148
@test sc.variables == sc2.variables

0 commit comments

Comments
 (0)