Skip to content

Commit b4e781c

Browse files
refactor: rework SymbolCache
1 parent a110ee4 commit b4e781c

File tree

2 files changed

+60
-16
lines changed

2 files changed

+60
-16
lines changed

src/symbol_cache.jl

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,59 @@ of having a vector of variables, parameters and independent variables. This
77
struct does not implement `observed`, and `is_observed` returns `false` for
88
all input symbols. It is considered to be time dependent if it contains
99
at least one independent variable.
10+
11+
The independent variable may be specified as a single symbolic variable instead of an
12+
array containing a single variable if the system has only one independent variable.
1013
"""
11-
struct SymbolCache{V, P, I}
12-
variables::Vector{V}
13-
parameters::Vector{P}
14-
independent_variables::Vector{I}
14+
struct SymbolCache{V<:Union{Nothing,AbstractVector}, P<:Union{Nothing,AbstractVector}, I}
15+
variables::V
16+
parameters::P
17+
independent_variables::I
1518
end
1619

17-
function SymbolCache(vars::Vector{V}, params = [], indepvars = []) where {V}
18-
return SymbolCache{V, eltype(params), eltype(indepvars)}(vars, params, indepvars)
20+
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing)
21+
return SymbolCache{typeof(vars),typeof(params),typeof(indepvars)}(vars, params, indepvars)
1922
end
2023

21-
is_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.variables)
22-
variable_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.variables)
23-
variable_symbols(sc::SymbolCache, i = nothing) = sc.variables
24-
is_parameter(sc::SymbolCache, sym) = any(isequal(sym), sc.parameters)
25-
parameter_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.parameters)
26-
parameter_symbols(sc::SymbolCache) = sc.parameters
27-
is_independent_variable(sc::SymbolCache, sym) = any(isequal(sym), sc.independent_variables)
28-
independent_variable_symbols(sc::SymbolCache) = sc.independent_variables
24+
is_variable(sc::SymbolCache, sym) = sc.variables !== nothing && any(isequal(sym), sc.variables)
25+
variable_index(sc::SymbolCache, sym) = sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables)
26+
variable_symbols(sc::SymbolCache, i = nothing) = something(sc.variables, [])
27+
is_parameter(sc::SymbolCache, sym) = sc.parameters !== nothing && any(isequal(sym), sc.parameters)
28+
parameter_index(sc::SymbolCache, sym) = sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters)
29+
parameter_symbols(sc::SymbolCache) = something(sc.parameters, [])
30+
function is_independent_variable(sc::SymbolCache, sym)
31+
sc.independent_variables === nothing && return false
32+
if symbolic_type(sc.independent_variables) == NotSymbolic()
33+
return any(isequal(sym), sc.independent_variables)
34+
elseif symbolic_type(sc.independent_variables) == ScalarSymbolic()
35+
return sym == sc.independent_variables
36+
else
37+
return any(isequal(sym), collect(sc.independent_variables))
38+
end
39+
end
40+
function independent_variable_symbols(sc::SymbolCache)
41+
sc.independent_variables === nothing && return []
42+
if symbolic_type(sc.independent_variables) == NotSymbolic()
43+
return sc.independent_variables
44+
elseif symbolic_type(sc.independent_variables) == ScalarSymbolic()
45+
return [sc.independent_variables]
46+
else
47+
return collect(sc.independent_variables)
48+
end
49+
end
2950
is_observed(sc::SymbolCache, sym) = false
30-
is_time_dependent(sc::SymbolCache) = !isempty(sc.independent_variables)
51+
function is_time_dependent(sc::SymbolCache)
52+
sc.independent_variables === nothing && return false
53+
if symbolic_type(sc.independent_variables) == NotSymbolic()
54+
return !isempty(sc.independent_variables)
55+
else
56+
return true
57+
end
58+
end
3159
constant_structure(::SymbolCache) = true
3260

3361
function Base.copy(sc::SymbolCache)
3462
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
3563
sc.parameters === nothing ? nothing : copy(sc.parameters),
36-
sc.independent_variables === nothing ? nothing : copy(sc.independent_variables))
64+
sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) : sc.independent_variables)
3765
end

test/symbol_cache_test.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,23 @@ sc = SymbolCache([:x, :y], [:a, :b])
1919
@test !is_time_dependent(sc)
2020
# make sure the constructor works
2121
@test_nowarn SymbolCache([:x, :y])
22+
23+
sc = SymbolCache()
24+
@test all(.!is_variable.((sc,), [:x, :y, :a, :b, :t]))
25+
@test all(variable_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing)
26+
@test variable_symbols(sc) == []
27+
@test all(.!is_parameter.((sc,), [:x, :y, :a, :b, :t]))
28+
@test all(parameter_index.((sc,), [:x, :y, :a, :b, :t]) .== nothing)
29+
@test parameter_symbols(sc) == []
30+
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b, :t]))
2231
@test independent_variable_symbols(sc) == []
32+
@test !is_time_dependent(sc)
33+
34+
sc = SymbolCache(nothing, nothing, :t)
35+
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
36+
@test is_independent_variable(sc, :t)
37+
@test independent_variable_symbols(sc) == [:t]
38+
@test is_time_dependent(sc)
2339

2440
sc2 = copy(sc)
2541
@test sc.variables == sc2.variables

0 commit comments

Comments
 (0)