Skip to content

Commit cb2136b

Browse files
authored
Merge pull request #281 from JuliaDynamics/hw/edgepairs
allow indexing of edges via pairs
2 parents e10d37d + fc7b246 commit cb2136b

File tree

3 files changed

+124
-21
lines changed

3 files changed

+124
-21
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- **improved Initialization System**: Added comprehensive initialization formulas and constraints system:
88
- added `@initformula` to add explicit algebraic init equations for specific variables
99
- added `@initconstraint` to add additional constraints for the component initialization
10+
- allow access edges via Pairs, i.e. `EIndex(1=>2,:a)` references variable `:a` in edge from vertex 1 to 2. Works also with unique names of vertices like `EIndex(:a=>:b)` [#281](https://github.com/JuliaDynamics/NetworkDynamics.jl/pull/281).
1011

1112
## v0.9 Changelog
1213
### Main changes in this release

src/symbolicindexing.jl

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
idx = VIndex(comp, sub)
44
55
A symbolic index for a vertex state variable.
6-
- `comp`: the component index, either int or a collection of ints
6+
- `comp`: the component index, either int, symbol or a collection
77
- `sub`: the subindex, either int, symbol or a collection of those.
88
99
```
1010
VIndex(1, :P) # vertex 1, variable :P
1111
VIndex(1:5, 1) # first state of vertices 1 to 5
1212
VIndex(7, (:x,:y)) # states :x and :y of vertex 7
1313
VIndex(2) # references the second vertex model
14+
VIndex(:a) # references vertex with unique name :a
1415
```
1516
1617
Can be used to index into objects supporting the `SymbolicIndexingInterface`,
@@ -28,14 +29,16 @@ VIndex(ci::Union{Symbol,Int}) = VIndex(ci, nothing)
2829
idx = EIndex(comp, sub)
2930
3031
A symbolic index for an edge state variable.
31-
- `comp`: the component index, either int or a collection of ints
32+
- `comp`: the component index, either int, symbol, pair or a collection
3233
- `sub`: the subindex, either int, symbol or a collection of those.
3334
3435
```
3536
EIndex(1, :P) # edge 1, variable :P
3637
EIndex(1:5, 1) # first state of edges 1 to 5
3738
EIndex(7, (:x,:y)) # states :x and :y of edge 7
3839
EIndex(2) # references the second edge model
40+
EIndex(1=>2) # references edge from v1 to v2
41+
EIndex(:a=>:b) # references edge from (uniquely named) vertex :a to :b
3942
```
4043
4144
Can be used to index into objects supporting the `SymbolicIndexingInterface`,
@@ -53,7 +56,7 @@ EIndex(ci::Union{Symbol,Int}) = EIndex(ci, nothing)
5356
idx = VPIndex(comp, sub)
5457
5558
A symbolic index into the parameter a vertex:
56-
- `comp`: the component index, either int or a collection of ints
59+
- `comp`: the component index, either int, symbol or a collection
5760
- `sub`: the subindex, either int, symbol or a collection of those.
5861
5962
Can be used to index into objects supporting the `SymbolicIndexingInterface`,
@@ -69,8 +72,8 @@ end
6972
EPIndex{C,S} <: SymbolicStateIndex{C,S}
7073
idx = VEIndex(comp, sub)
7174
72-
A symbolic index into the parameter a vertex:
73-
- `comp`: the component index, either int or a collection of ints
75+
A symbolic index into the parameter of an edge:
76+
- `comp`: the component index, either int, symbol, pair or a collection
7477
- `sub`: the subindex, either int, symbol or a collection of those.
7578
7679
Can be used to index into objects supporting the `SymbolicIndexingInterface`,
@@ -95,18 +98,25 @@ SSI Maintainer assured that f.sys is really only used for symbolic indexig so me
9598
SciMLBase.__has_sys(nw::Network) = true
9699
Base.getproperty(nw::Network, s::Symbol) = s===:sys ? nw : getfield(nw, s)
97100

98-
SII.symbolic_type(::Type{<:SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}}) = SII.ScalarSymbolic()
101+
SII.symbolic_type(::Type{<:SymbolicIndex{<:Union{<:Pair,Symbol,Int},<:Union{Symbol,Int}}}) = SII.ScalarSymbolic()
99102
SII.symbolic_type(::Type{<:SymbolicIndex}) = SII.ArraySymbolic()
100103

101104
SII.hasname(::SymbolicIndex) = false
102-
SII.hasname(::SymbolicIndex{<:Union{Symbol,Int},<:Union{Symbol,Int}}) = true
105+
SII.hasname(::SymbolicIndex{<:Union{<:Pair,Symbol,Int},<:Union{Symbol,Int}}) = true
103106
function SII.getname(x::SymbolicVertexIndex)
104107
prefix = x.compidx isa Int ? :v : Symbol()
105108
Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx))
106109
end
107110
function SII.getname(x::SymbolicEdgeIndex)
108-
prefix = x.compidx isa Int ? :e : Symbol()
109-
Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx))
111+
if x.compidx isa Pair
112+
src, dst = x.compidx
113+
_src = src isa Int ? Symbol(:v, src) : Symbol(src)
114+
_dst = dst isa Int ? Symbol(:v, dst) : Symbol(dst)
115+
Symbol(_src, "ₜₒ", _dst, :₊, Symbol(x.subidx))
116+
else
117+
prefix = x.compidx isa Int ? :e : Symbol()
118+
Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx))
119+
end
110120
end
111121

112122
resolvecompidx(nw::Network, sni) = resolvecompidx(nw.im, sni)
@@ -119,20 +129,49 @@ function resolvecompidx(im::IndexManager, sni::SymbolicIndex{Symbol})
119129
throw(ArgumentError("Could not resolve component index for $sni, the name might not be unique?"))
120130
end
121131
end
132+
function resolvecompidx(im::IndexManager, sni::SymbolicEdgeIndex{<:Pair})
133+
src, dst = sni.compidx
134+
135+
src_i = try
136+
resolvecompidx(im, VIndex(src))
137+
catch
138+
throw(ArgumentError("Could not resolve edge source $src"))
139+
end
140+
dst_i = try
141+
resolvecompidx(im, VIndex(dst))
142+
catch
143+
throw(ArgumentError("Could not resolve edge destination $dst"))
144+
end
145+
146+
eidx = findfirst(im.edgevec) do e
147+
e.src == src_i && e.dst == dst_i
148+
end
149+
if isnothing(eidx)
150+
reverse = findfirst(im.edgevec) do e
151+
e.src == dst_i && e.dst == src_i
152+
end
153+
err = "Invalid Index: Network does not contain edge from $(src) => $(dst)!"
154+
if !isnothing(reverse)
155+
err *= " Maybe you meant the reverse edge from $(dst) => $(src)?"
156+
end
157+
throw(ArgumentError(err))
158+
end
159+
return eidx
160+
end
122161
getcomp(nw::Network, sni) = getcomp(nw.im, sni)
123162
getcomp(im::IndexManager, sni::SymbolicEdgeIndex) = im.edgem[resolvecompidx(im, sni)]
124163
getcomp(im::IndexManager, sni::SymbolicVertexIndex) = im.vertexm[resolvecompidx(im, sni)]
125164

126165
getcomprange(nw::Network, sni) = getcomprange(nw.im, sni)
127166
getcomprange(im::IndexManager, sni::VIndex{<:Union{Symbol,Int}}) = im.v_data[resolvecompidx(im, sni)]
128-
getcomprange(im::IndexManager, sni::EIndex{<:Union{Symbol,Int}}) = im.e_data[resolvecompidx(im, sni)]
167+
getcomprange(im::IndexManager, sni::EIndex{<:Union{<:Pair,Symbol,Int}}) = im.e_data[resolvecompidx(im, sni)]
129168

130169
getcompoutrange(nw::Network, sni) = getcompoutrange(nw.im, sni)
131170
getcompoutrange(im::IndexManager, sni::VIndex{<:Union{Symbol,Int}}) = im.v_out[resolvecompidx(im, sni)]
132-
getcompoutrange(im::IndexManager, sni::EIndex{<:Union{Symbol,Int}}) = flatrange(im.e_out[resolvecompidx(im, sni)])
171+
getcompoutrange(im::IndexManager, sni::EIndex{<:Union{<:Pair,Symbol,Int}}) = flatrange(im.e_out[resolvecompidx(im, sni)])
133172

134173
getcompprange(nw::Network, sni::SymbolicVertexIndex{<:Union{Symbol,Int}}) = nw.im.v_para[resolvecompidx(nw, sni)]
135-
getcompprange(nw::Network, sni::SymbolicEdgeIndex{<:Union{Symbol,Int}}) = nw.im.e_para[resolvecompidx(nw, sni)]
174+
getcompprange(nw::Network, sni::SymbolicEdgeIndex{<:Union{<:Pair,Symbol,Int}}) = nw.im.e_para[resolvecompidx(nw, sni)]
136175

137176
subsym_has_idx(sym::Symbol, syms) = sym syms
138177
subsym_has_idx(idx::Int, syms) = 1 idx length(syms)
@@ -143,7 +182,7 @@ subsym_to_idx(idx::Int, _) = idx
143182
#### Iterator/Broadcast interface for ArraySymbolic types
144183
####
145184
# TODO: not broadcasting over idx with colon is weird
146-
Base.broadcastable(si::SymbolicIndex{<:Union{Int,Symbol,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si)
185+
Base.broadcastable(si::SymbolicIndex{<:Union{Int,Symbol,<:Pair,Colon},<:Union{Int,Symbol,Colon}}) = Ref(si)
147186

148187
const _IterableComponent = SymbolicIndex{<:Union{AbstractVector,Tuple},<:Union{Int,Symbol}}
149188
Base.length(si::_IterableComponent) = length(si.compidx)
@@ -166,7 +205,7 @@ function Base.iterate(si::_IterableComponent, state=nothing)
166205
_similar(si, it[1], si.subidx), it[2]
167206
end
168207

169-
const _IterableSubcomponent = SymbolicIndex{<:Union{Symbol,Int},<:Union{AbstractVector,Tuple}}
208+
const _IterableSubcomponent = SymbolicIndex{<:Union{<:Pair,Symbol,Int},<:Union{AbstractVector,Tuple}}
170209
Base.length(si::_IterableSubcomponent) = length(si.subidx)
171210
Base.size(si::_IterableSubcomponent) = (length(si),)
172211
Base.IteratorSize(si::_IterableSubcomponent) = Base.HasShape{1}()
@@ -200,9 +239,9 @@ _resolve_colon(nw::Network, sni::EIndex{Colon}) = EIndex(1:ne(nw), sni.subidx)
200239
_resolve_colon(nw::Network, sni::VPIndex{Colon}) = VPIndex(1:nv(nw), sni.subidx)
201240
_resolve_colon(nw::Network, sni::EPIndex{Colon}) = EPIndex(1:ne(nw), sni.subidx)
202241
_resolve_colon(nw::Network, sni::VIndex{<:Union{Symbol,Int},Colon}) = VIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni)))
203-
_resolve_colon(nw::Network, sni::EIndex{<:Union{Symbol,Int},Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni)))
242+
_resolve_colon(nw::Network, sni::EIndex{<:Union{<:Pair,Symbol,Int},Colon}) = EIndex{Int, UnitRange{Int}}(sni.compidx, 1:dim(getcomp(nw,sni)))
204243
_resolve_colon(nw::Network, sni::VPIndex{<:Union{Symbol,Int},Colon}) = VPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni)))
205-
_resolve_colon(nw::Network, sni::EPIndex{<:Union{Symbol,Int},Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni)))
244+
_resolve_colon(nw::Network, sni::EPIndex{<:Union{<:Pair,Symbol,Int},Colon}) = EPIndex{Int, UnitRange{Int}}(sni.compidx, 1:pdim(getcomp(nw,sni)))
206245

207246

208247
#### Implmentation of index provider interface
@@ -220,7 +259,7 @@ SII.all_symbols(nw::Network) = vcat(SII.all_variable_symbols(nw), SII.parameter_
220259
####
221260
#### variable indexing
222261
####
223-
const POTENTIAL_SCALAR_SIDX = Union{SymbolicStateIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}}}
262+
const POTENTIAL_SCALAR_SIDX = Union{SymbolicStateIndex{<:Union{<:Pair,Symbol,Int},<:Union{Int,Symbol}}}
224263
function SII.is_variable(nw::Network, sni)
225264
if _hascolon(sni)
226265
SII.is_variable(nw, _resolve_colon(nw,sni))
@@ -277,8 +316,8 @@ end
277316
####
278317
# when using an number instead of symbol only PIndex is valid
279318
const POTENTIAL_SCALAR_PIDX = Union{
280-
SymbolicParameterIndex{<:Union{Symbol,Int},<:Union{Int,Symbol}},
281-
SymbolicIndex{<:Union{Symbol,Int},Symbol}
319+
SymbolicParameterIndex{<:Union{<:Pair,Symbol,Int},<:Union{Int,Symbol}},
320+
SymbolicIndex{<:Union{<:Pair,Symbol,Int},Symbol}
282321
}
283322
function SII.is_parameter(nw::Network, sni)
284323
if _hascolon(sni)
@@ -410,7 +449,7 @@ function SII.is_observed(nw::Network, sni)
410449
end
411450
end
412451
_is_observed(nw::Network, _) = false
413-
function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{Symbol,Int},Symbol})
452+
function _is_observed(nw::Network, sni::SymbolicStateIndex{<:Union{<:Pair,Symbol,Int},Symbol})
414453
cf = getcomp(nw, sni)
415454
return sni.subidx obssym_all(cf)
416455
end
@@ -1212,7 +1251,7 @@ Base.getindex(s::NWState, idx::ObservableExpression) = SII.getu(s, idx)(s)
12121251
Base.getindex(s::NWParameter, idx::ObservableExpression) = SII.getp(s, idx)(s)
12131252

12141253
# using getindex to access component models
1215-
function Base.getindex(nw::Network, i::EIndex{<:Union{Symbol,Int}, Nothing})
1254+
function Base.getindex(nw::Network, i::EIndex{<:Union{<:Pair,Symbol,Int}, Nothing})
12161255
return getcomp(nw, i)
12171256
end
12181257
function Base.getindex(nw::Network, i::VIndex{<:Union{Symbol,Int}, Nothing})

test/symbolicindexing_test.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,66 @@ end
552552
b = @b $obsf2($(rand(dim(nw))), $(rand(pdim(nw))), NaN, $(zeros(length(idxs2)))) # 17ns 0 allocs
553553
@test b.allocs == 0
554554
end
555+
556+
@testset "test edge indexing with Pair syntax" begin
557+
# Create a simple network with named vertices for testing
558+
v1 = Lib.kuramoto_second(name=:v1)
559+
v2 = Lib.kuramoto_second(name=:v2)
560+
v3 = Lib.kuramoto_second(name=:v3)
561+
e12 = Lib.kuramoto_edge(name=:e12)
562+
e23 = Lib.kuramoto_edge(name=:e23)
563+
g = path_graph(3)
564+
nw = Network(g, [v1, v2, v3], [e12, e23])
565+
566+
# Basic pair syntax tests
567+
@test SII.is_observed(nw, EIndex(1=>2, :P))
568+
@test SII.is_observed(nw, EIndex(:v1=>:v2, :P))
569+
@test SII.is_parameter(nw, EPIndex(1=>2, :K))
570+
@test SII.is_parameter(nw, EPIndex(:v1=>:v2, :K))
571+
572+
# Test that pairs resolve to regular indices
573+
@test SII.parameter_index(nw, EPIndex(1=>2, :K)) == SII.parameter_index(nw, EPIndex(1, :K))
574+
@test SII.parameter_index(nw, EPIndex(:v1=>:v2, :K)) == SII.parameter_index(nw, EPIndex(1, :K))
575+
@test SII.parameter_index(nw, EPIndex(:v1=>2, :K)) == SII.parameter_index(nw, EPIndex(1, :K))
576+
577+
# Error cases
578+
@test_throws ArgumentError SII.parameter_index(nw, EPIndex(1=>3, :K)) # no direct edge
579+
@test_throws ArgumentError SII.parameter_index(nw, EPIndex(:nonexistent=>:v2, :K)) # invalid vertex
580+
581+
# Test with solution and state objects
582+
u0 = rand(dim(nw))
583+
p = rand(pdim(nw))
584+
prob = ODEProblem(nw, u0, (0.0, 1.0), p)
585+
sol = solve(prob, Tsit5())
586+
s = NWState(nw, u0, p)
587+
588+
# Solution indexing
589+
@test sol([0.1, 0.5], idxs=EIndex(1=>2, :P)).u sol([0.1, 0.5], idxs=EIndex(1, :P)).u
590+
@test sol([0.1, 0.5], idxs=EIndex(:v1=>:v2, :P)).u sol([0.1, 0.5], idxs=EIndex(1, :P)).u
591+
592+
# State access
593+
@test s[EIndex(1=>2, :P)] == s[EIndex(1, :P)]
594+
@test s[EIndex(:v1=>:v2, :P)] == s[EIndex(1, :P)]
595+
@test s[EPIndex(1=>2, :K)] == s[EPIndex(1, :K)]
596+
@test s[EPIndex(:v1=>:v2, :K)] == s[EPIndex(1, :K)]
597+
598+
# Proxy syntax
599+
@test s.e[1=>2, :P] == s[EIndex(1, :P)]
600+
@test s.e[:v1=>:v2, :P] == s[EIndex(1, :P)]
601+
s.p.e[1=>2, :K] = 2.71
602+
@test s.p.e[1, :K] == 2.71
603+
@test s.p[EPIndex(1=>2, :K)] == 2.71
604+
605+
# Multiple edge access
606+
edges_int = [EIndex(1=>2, :P), EIndex(2=>3, :P)]
607+
edges_named = [EIndex(:v1=>:v2, :P), EIndex(:v2=>:v3, :P)]
608+
edges_regular = [EIndex(1, :P), EIndex(2, :P)]
609+
@test s[edges_int] == s[edges_regular]
610+
@test s[edges_named] == s[edges_regular]
611+
612+
# naming
613+
@test SII.getname(EIndex(1=>2, :P)) == :v1ₜₒv2₊P
614+
@test SII.getname(EIndex(:a=>:b, :P)) == :aₜₒb₊P
615+
@test SII.getname(EIndex(:a=>2, :P)) == :aₜₒv2₊P
616+
@test SII.getname(EIndex(1=>:b, :P)) == :v1ₜₒb₊P
617+
end

0 commit comments

Comments
 (0)