Skip to content

Commit ce2c266

Browse files
revisiting TemporalSnapshotsGNNGraph
1 parent 27d13c8 commit ce2c266

File tree

2 files changed

+79
-70
lines changed

2 files changed

+79
-70
lines changed

GNNGraphs/src/temporalsnapshotsgnngraph.jl

Lines changed: 58 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,68 @@
11
"""
2-
TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
2+
TemporalSnapshotsGNNGraph(snapshots)
33
4-
A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref).
4+
A type representing a time-varying graph as a sequence of snapshots,
5+
each snapshot being a [`GNNGraph`](@ref).
56
6-
`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object,
7-
and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features.
8-
The features can be passed at construction time or added later.
7+
The argument `snapshots` is a collection of `GNNGraph`s with arbitrary
8+
number of nodes and edges each.
99
10-
# Constructor Arguments
10+
Calling `tg` the temporal graph, `tg[t]` returns the `t`-th snapshot.
1111
12-
- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes.
12+
The snapshots can contain node/edge/graph features, while global features for the
13+
whole temporal sequence can be stored in `tg.tgdata`.
1314
14-
# Examples
15+
See [`add_snapshot`](@ref) and [`remove_snapshot`](@ref) for adding and removing snapshots.
1516
16-
```julia
17-
julia> using GNNGraphs
17+
# Examples
1818
19-
julia> snapshots = [rand_graph(10,20) for i in 1:5];
19+
```jldoctest
20+
julia> snapshots = [rand_graph(i , 2*i) for i in 10:10:50];
2021
2122
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
2223
TemporalSnapshotsGNNGraph:
23-
num_nodes: [10, 10, 10, 10, 10]
24-
num_edges: [20, 20, 20, 20, 20]
24+
num_nodes: [10, 20, 30, 40, 50]
25+
num_edges: [20, 40, 60, 80, 100]
2526
num_snapshots: 5
2627
27-
julia> tg.tgdata.x = rand(4); # add temporal graph feature
28+
julia> tg.num_snapshots
29+
5
2830
29-
julia> tg # show temporal graph with new feature
31+
julia> tg.num_nodes
32+
5-element Vector{Int64}:
33+
10
34+
20
35+
30
36+
40
37+
50
38+
39+
julia> tg[1]
40+
GNNGraph:
41+
num_nodes: 10
42+
num_edges: 20
43+
44+
julia> tg[2:3]
3045
TemporalSnapshotsGNNGraph:
31-
num_nodes: [10, 10, 10, 10, 10]
32-
num_edges: [20, 20, 20, 20, 20]
33-
num_snapshots: 5
34-
tgdata:
35-
x = 4-element Vector{Float64}
46+
num_nodes: [20, 30]
47+
num_edges: [40, 60]
48+
num_snapshots: 2
3649
```
3750
"""
38-
struct TemporalSnapshotsGNNGraph
39-
num_nodes::AbstractVector{Int}
40-
num_edges::AbstractVector{Int}
51+
struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore}
52+
num_nodes::Vector{Int}
53+
num_edges::Vector{Int}
4154
num_snapshots::Int
42-
snapshots::AbstractVector{<:GNNGraph}
43-
tgdata::DataStore
55+
snapshots::Vector{G}
56+
tgdata::D
4457
end
4558

46-
function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
47-
@assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
59+
function TemporalSnapshotsGNNGraph(snapshots)
60+
snapshots = collect(snapshots)
4861
return TemporalSnapshotsGNNGraph(
4962
[s.num_nodes for s in snapshots],
5063
[s.num_edges for s in snapshots],
5164
length(snapshots),
52-
snapshots,
65+
collect(snapshots),
5366
DataStore()
5467
)
5568
end
@@ -67,7 +80,19 @@ function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
6780
end
6881

6982
function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
70-
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
83+
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t],
84+
length(t), tg.snapshots[t], tg.tgdata)
85+
end
86+
87+
function Base.length(tg::TemporalSnapshotsGNNGraph)
88+
return tg.num_snapshots
89+
end
90+
91+
function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int)
92+
tg.snapshots[t] = g
93+
tg.num_nodes[t] = g.num_nodes
94+
tg.num_edges[t] = g.num_edges
95+
return tg
7196
end
7297

7398
"""
@@ -78,8 +103,6 @@ Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the sn
78103
# Examples
79104
80105
```jldoctest
81-
julia> using GNNGraphs
82-
83106
julia> snapshots = [rand_graph(10, 20) for i in 1:5];
84107
85108
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -185,14 +208,8 @@ end
185208
function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol)
186209
if prop fieldnames(TemporalSnapshotsGNNGraph)
187210
return getfield(tg, prop)
188-
elseif prop == :ndata
189-
return [s.ndata for s in tg.snapshots]
190-
elseif prop == :edata
191-
return [s.edata for s in tg.snapshots]
192-
elseif prop == :gdata
193-
return [s.gdata for s in tg.snapshots]
194-
else
195-
return [getproperty(s,prop) for s in tg.snapshots]
211+
else
212+
return [getproperty(s, prop) for s in tg.snapshots]
196213
end
197214
end
198215

@@ -204,39 +221,15 @@ end
204221

205222
function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
206223
if get(io, :compact, false)
207-
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
208-
print_feature_t(io, tsg.tgdata)
209-
print(io, " data")
224+
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))")
210225
else
211226
print(io,
212227
"TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)")
213228
if !isempty(tsg.tgdata)
214229
print(io, "\n tgdata:")
215230
for k in keys(tsg.tgdata)
216-
print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))")
217-
end
218-
end
219-
end
220-
end
221-
222-
function print_feature_t(io::IO, feature)
223-
if !isempty(feature)
224-
if length(keys(feature)) == 1
225-
k = first(keys(feature))
226-
v = first(values(feature))
227-
print(io, "$(k): $(dims2string(size(v)))")
228-
else
229-
print(io, "(")
230-
for (i, (k, v)) in enumerate(pairs(feature))
231-
print(io, "$k: $(dims2string(size(v)))")
232-
if i == length(feature)
233-
print(io, ")")
234-
else
235-
print(io, ", ")
236-
end
231+
print(io, "\n $k = $(shortsummary(tsg.tgdata[k]))")
237232
end
238233
end
239-
else
240-
print(io, "no")
241234
end
242235
end

GNNGraphs/test/temporalsnapshotsgnngraph.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
@testset "Constructor array TemporalSnapshotsGNNGraph" begin
22
snapshots = [rand_graph(10, 20) for i in 1:5]
3-
tsg = TemporalSnapshotsGNNGraph(snapshots)
4-
@test tsg.num_nodes == [10 for i in 1:5]
5-
@test tsg.num_edges == [20 for i in 1:5]
6-
wrsnapshots = [rand_graph(10,20), rand_graph(12,22)]
7-
@test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots)
3+
tg = TemporalSnapshotsGNNGraph(snapshots)
4+
@test tg.num_nodes == [10 for i in 1:5]
5+
@test tg.num_edges == [20 for i in 1:5]
6+
@test tg.num_snapshots == 5
7+
8+
snapshots = [rand_graph(i, 2*i) for i in 10:10:50]
9+
tg = TemporalSnapshotsGNNGraph(snapshots)
10+
@test tg.num_nodes == [i for i in 10:10:50]
11+
@test tg.num_edges == [2*i for i in 10:10:50]
12+
@test tg.num_snapshots == 5
813
end
914

1015
@testset "==" begin
@@ -24,6 +29,17 @@ end
2429
@test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata)
2530
end
2631

32+
@testset "setindex!" begin
33+
snapshots = [rand_graph(10, 20) for i in 1:5]
34+
tsg = TemporalSnapshotsGNNGraph(snapshots)
35+
g = rand_graph(20, 40)
36+
tsg[3] = g
37+
@test tsg.snapshots[3] === g
38+
@test tsg.num_nodes == [10, 10, 20, 10, 10]
39+
@test tsg.num_edges == [20, 20, 40, 20, 20]
40+
@test_throws MethodError tsg[3:4] = g
41+
end
42+
2743
@testset "getproperty" begin
2844
x = rand(10)
2945
snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5]

0 commit comments

Comments
 (0)