Skip to content

Commit 2cb9486

Browse files
committed
disallow SymblolicStateIndex{?, Int} for p indexing
when using int to specify the symbol it musst be SymbolicParameterIndex
1 parent 7029243 commit 2cb9486

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

src/symbolicindexing.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ SII.all_symbols(nw::Network) = vcat(SII.all_variable_symbols(nw), SII.parameter_
220220
####
221221
#### variable indexing
222222
####
223+
const POTENTIAL_SCALAR_SIDX = Union{SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}}
223224
function SII.is_variable(nw::Network, sni)
224225
if _hascolon(sni)
225226
SII.is_variable(nw, _resolve_colon(nw,sni))
@@ -230,7 +231,7 @@ function SII.is_variable(nw::Network, sni)
230231
end
231232
end
232233
_is_variable(nw::Network, sni) = false
233-
function _is_variable(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}})
234+
function _is_variable(nw::Network, sni::POTENTIAL_SCALAR_SIDX)
234235
cf = getcomp(nw, sni)
235236
return subsym_has_idx(sni.subidx, sym(cf))
236237
end
@@ -244,7 +245,7 @@ function SII.variable_index(nw::Network, sni)
244245
_variable_index(nw, sni)
245246
end
246247
end
247-
function _variable_index(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}})
248+
function _variable_index(nw::Network, sni::POTENTIAL_SCALAR_SIDX)
248249
cf = getcomp(nw, sni)
249250
range = getcomprange(nw, sni)
250251
range[subsym_to_idx(sni.subidx, sym(cf))]
@@ -265,6 +266,11 @@ end
265266
####
266267
#### parameter indexing
267268
####
269+
# when using an number instead of symbol only PIndex is valid
270+
const POTENTIAL_SCALAR_PIDX = Union{
271+
SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}},
272+
SymbolicIndex{<:Union{Symbol,Int},Symbol}
273+
}
268274
function SII.is_parameter(nw::Network, sni)
269275
if _hascolon(sni)
270276
SII.is_parameter(nw, _resolve_colon(nw,sni))
@@ -275,8 +281,7 @@ function SII.is_parameter(nw::Network, sni)
275281
end
276282
end
277283
_is_parameter(nw::Network, sni) = false
278-
function _is_parameter(nw::Network,
279-
sni::SymbolicIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}})
284+
function _is_parameter(nw::Network, sni::POTENTIAL_SCALAR_PIDX)
280285
cf = getcomp(nw, sni)
281286
return subsym_has_idx(sni.subidx, psym(cf))
282287
end
@@ -290,15 +295,14 @@ function SII.parameter_index(nw::Network, sni)
290295
_parameter_index(nw, sni)
291296
end
292297
end
293-
function _parameter_index(nw::Network,
294-
sni::SymbolicIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}})
298+
function _parameter_index(nw::Network, sni::POTENTIAL_SCALAR_PIDX)
295299
cf = getcomp(nw, sni)
296300
range = getcompprange(nw, sni)
297301
range[subsym_to_idx(sni.subidx, psym(cf))]
298302
end
299303

300304
function SII.parameter_symbols(nw::Network)
301-
syms = Vector{SymbolicIndex{Int,Symbol}}(undef, pdim(nw))
305+
syms = Vector{SymbolicParameterIndex{Int,Symbol}}(undef, pdim(nw))
302306
for (ci, cf) in pairs(nw.im.vertexm)
303307
syms[nw.im.v_para[ci]] .= VPIndex.(ci, psym(cf))
304308
end
@@ -388,7 +392,7 @@ function SII.is_observed(nw::Network, sni)
388392
end
389393
end
390394
_is_observed(nw::Network, _) = false
391-
function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}})
395+
function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},Symbol})
392396
cf = getcomp(nw, sni)
393397
return sni.subidx obssym_all(cf)
394398
end

test/symbolicindexing_test.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,12 @@ _uflat = copy(sol(t))
103103
_pflat = copy(sol.prob.p)
104104
s = NWState(nw, _uflat, _pflat)
105105

106-
SII.getu(s, EIndex(1,:e_dst))(s)
107-
SII.getp(s, VPIndex(1,:M))(s)
106+
@test SII.getu(s, EIndex(1,:e_dst))(s) == uflat(s)[7]
107+
@test SII.getp(s, VPIndex(1,:M))(s) ==pflat(s)[1]
108+
@test SII.getp(s, VIndex(1,:M))(s) ==pflat(s)[1]
108109

109-
SII.is_variable(nw, EIndex(1,:e_dst))
110-
SII.variable_index(nw, EIndex(1,:e_dst))
110+
@test SII.is_variable(nw, EIndex(1,:e_dst))
111+
@test SII.variable_index(nw, EIndex(1,:e_dst)) == 7
111112

112113
@test map(idx->s[idx], SII.variable_symbols(nw)) == _uflat
113114
@test map(idx->s[idx], SII.parameter_symbols(nw)) == _pflat
@@ -185,10 +186,10 @@ for et in [VIndex, EIndex, VPIndex, EPIndex]
185186
repr.(et(1,))
186187
end
187188

188-
@test s[[VIndex(1,1), VPIndex(1,2)]] == [s[VIndex(1,1)], s[VPIndex(1,2)]]
189-
@test s[(VIndex(1,1), VPIndex(1,2))] == (s[VIndex(1,1)], s[VPIndex(1,2)])
190-
@test s[[VIndex(1,1), VIndex(1,2)]] == [s[VIndex(1,1)], s[VIndex(1,2)]]
191-
@test s[(VIndex(1,1), VIndex(1,2))] == (s[VIndex(1,1)], s[VIndex(1,2)])
189+
@test s[[VIndex(1,1), VPIndex(1,2)]] == [s[VIndex(1,1)], s[VPIndex(1,2)]] == [uflat(s)[1], pflat(s)[2]]
190+
@test s[(VIndex(1,1), VPIndex(1,2))] == (s[VIndex(1,1)], s[VPIndex(1,2)]) == (uflat(s)[1], pflat(s)[2])
191+
@test s[[VIndex(1,1), VIndex(1,2)]] == [s[VIndex(1,1)], s[VIndex(1,2)]] == [uflat(s)[1], uflat(s)[2]]
192+
@test s[(VIndex(1,1), VIndex(1,2))] == (s[VIndex(1,1)], s[VIndex(1,2)]) == (uflat(s)[1], uflat(s)[2])
192193

193194
@test s[VIndex(:,1)] == s[VIndex(1:4,1)]
194195
@test_broken s[EIndex(:,1)] == s[EIndex(1:6,1)]

0 commit comments

Comments
 (0)