Skip to content

Commit 65c0faa

Browse files
authored
Add TemporalSnapshotsGNNGraph struct (#293)
* Add `TemporalSnapshotsGNNGraph` struct * Remove typo * Add `==` function * Add `add/remove_snaposhot` and `show` functions * Add test * Export `TemporalSnapshotsGNNgraph` functions * Add temporalsnapshotsgnngraph tests * Rename file and function * Fix comma * Add test `show`
1 parent d394b91 commit 65c0faa

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ export GNNGraph,
3030
include("gnnheterograph.jl")
3131
export GNNHeteroGraph
3232

33+
34+
include("temporalsnapshotsgnngraph.jl")
35+
export TemporalSnapshotsGNNGraph,
36+
add_snapshot,
37+
remove_snapshot
38+
3339
include("query.jl")
3440
export adjacency_list,
3541
edge_index,
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
struct TemporalSnapshotsGNNGraph
2+
num_nodes::Vector{Int}
3+
num_edges::Vector{Int}
4+
num_snapshots::Int
5+
snapshots::Vector{<:GNNGraph}
6+
tgdata::DataStore
7+
end
8+
9+
function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
10+
@assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
11+
return TemporalSnapshotsGNNGraph(
12+
[s.num_nodes for s in snapshots],
13+
[s.num_edges for s in snapshots],
14+
length(snapshots),
15+
snapshots,
16+
DataStore()
17+
)
18+
end
19+
20+
function Base.:(==)(tsg1::TemporalSnapshotsGNNGraph, tsg2::TemporalSnapshotsGNNGraph)
21+
tsg1 === tsg2 && return true
22+
for k in fieldnames(typeof(tsg1))
23+
getfield(tsg1, k) != getfield(tsg2, k) && return false
24+
end
25+
return true
26+
end
27+
28+
function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
29+
return tg.snapshots[t]
30+
end
31+
32+
function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
33+
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
34+
end
35+
36+
function add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph)
37+
@assert g.num_nodes == tg.num_nodes[t] "number of nodes must match"
38+
num_nodes= tg.num_nodes
39+
num_edges = tg.num_edges
40+
snapshots = tg.snapshots
41+
num_snapshots = tg.num_snapshots + 1
42+
insert!(num_nodes, t, g.num_nodes)
43+
insert!(num_edges, t, g.num_edges)
44+
insert!(snapshots, t, g)
45+
return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata)
46+
end
47+
48+
function remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int)
49+
num_nodes= tg.num_nodes
50+
num_edges = tg.num_edges
51+
snapshots = tg.snapshots
52+
num_snapshots = tg.num_snapshots - 1
53+
deleteat!(num_nodes, t)
54+
deleteat!(num_edges, t)
55+
deleteat!(snapshots, t)
56+
return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata)
57+
end
58+
59+
function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph)
60+
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
61+
print_feature_t(io, tsg.tgdata)
62+
print(io, " data")
63+
end
64+
65+
function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
66+
if get(io, :compact, false)
67+
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
68+
print_feature_t(io, tsg.tgdata)
69+
print(io, " data")
70+
else
71+
print(io,
72+
"TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)")
73+
if !isempty(tsg.tgdata)
74+
print(io, "\n tgdata:")
75+
for k in keys(tsg.tgdata)
76+
print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))")
77+
end
78+
end
79+
end
80+
end
81+
82+
83+
function print_feature_t(io::IO, feature)
84+
if !isempty(feature)
85+
if length(keys(feature)) == 1
86+
k = first(keys(feature))
87+
v = first(values(feature))
88+
print(io, "$(k): $(dims2string(size(v)))")
89+
else
90+
print(io, "(")
91+
for (i, (k, v)) in enumerate(pairs(feature))
92+
print(io, "$k: $(dims2string(size(v)))")
93+
if i == length(feature)
94+
print(io, ")")
95+
else
96+
print(io, ", ")
97+
end
98+
end
99+
end
100+
else
101+
print(io, "no")
102+
end
103+
end
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
@testset "Constructor array TemporalSnapshotsGNNGraph" begin
2+
snapshots = [rand_graph(10, 20) for i in 1:5]
3+
tsg = TemporalSnapshotsGNNGraph(snapshots)
4+
@test tsg.num_nodes == [10 for i in 1:5]
5+
@test tsg.num_edges == [20 for i in 1:5]
6+
wrsnapshots = [rand_graph(10,20), rand_graph(12,22)]
7+
@test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots)
8+
end
9+
10+
@testset "==" begin
11+
snapshots = [rand_graph(10, 20) for i in 1:5]
12+
tsg1 = TemporalSnapshotsGNNGraph(snapshots)
13+
tsg2 = TemporalSnapshotsGNNGraph(snapshots)
14+
@test tsg1 == tsg2
15+
tsg3 = TemporalSnapshotsGNNGraph(snapshots[1:3])
16+
@test tsg1 != tsg3
17+
@test tsg1 !== tsg3
18+
end
19+
20+
@testset "getindex" begin
21+
snapshots = [rand_graph(10, 20) for i in 1:5]
22+
tsg = TemporalSnapshotsGNNGraph(snapshots)
23+
@test tsg[3] == snapshots[3]
24+
@test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata)
25+
end
26+
27+
@testset "add/remove_snapshot" begin
28+
snapshots = [rand_graph(10, 20) for i in 1:5]
29+
tsg = TemporalSnapshotsGNNGraph(snapshots)
30+
g = rand_graph(10, 20)
31+
tsg = add_snapshot(tsg, 3, g)
32+
@test tsg.num_nodes == [10 for i in 1:6]
33+
@test tsg.num_edges == [20 for i in 1:6]
34+
@test tsg.snapshots[3] == g
35+
tsg = remove_snapshot(tsg, 3)
36+
@test tsg.num_nodes == [10 for i in 1:5]
37+
@test tsg.num_edges == [20 for i in 1:5]
38+
@test tsg.snapshots == snapshots
39+
end
40+
41+
@testset "show" begin
42+
snapshots = [rand_graph(10, 20) for i in 1:5]
43+
tsg = TemporalSnapshotsGNNGraph(snapshots)
44+
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data"
45+
@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data"
46+
@test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5"
47+
tsg.tgdata.x=rand(4)
48+
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
49+
end
50+
51+
# @test sprint(show, MIME("text/plain"), rand_graph(10, 20); context=:compact => true) == "GNNGraph(10, 20) with no data"

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ tests = [
3434
"GNNGraphs/query",
3535
"GNNGraphs/sampling",
3636
"GNNGraphs/gnnheterograph",
37+
"GNNGraphs/temporalsnapshotsgnngraph",
3738
"utils",
3839
"msgpass",
3940
"layers/basic",

0 commit comments

Comments
 (0)