Skip to content

Commit 4e88944

Browse files
iteration and broadcast
1 parent 909df5a commit 4e88944

File tree

6 files changed

+130
-1670
lines changed

6 files changed

+130
-1670
lines changed

GNNGraphs/docs/src/guides/temporalgraph.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,29 @@ GNNGraph:
8585
num_nodes: 10
8686
num_edges: 16
8787
```
88+
## Iteration and Broadcasting
89+
90+
Iteration and broadcasting over a temporal graph is similar to that of a vector of snapshots:
91+
92+
```jldoctest temporal
93+
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];
94+
95+
julia> tg = TemporalSnapshotsGNNGraph(snapshots);
96+
97+
julia> [g for g in tg] # iterate over snapshots
98+
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
99+
GNNGraph(10, 20) with no data
100+
GNNGraph(10, 14) with no data
101+
GNNGraph(10, 22) with no data
102+
103+
julia> f(g) = g isa GNNGraph;
104+
105+
julia> f.(tg) # broadcast over snapshots
106+
3-element BitVector:
107+
1
108+
1
109+
1
110+
```
88111

89112
## Basic Queries
90113

GNNGraphs/src/temporalsnapshotsgnngraph.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,13 @@ end
9191

9292
function Base.length(tg::TemporalSnapshotsGNNGraph)
9393
return tg.num_snapshots
94-
end
94+
end
95+
96+
# Allow broadcasting over the temporal snapshots
97+
Base.broadcastable(tg::TemporalSnapshotsGNNGraph) = tg.snapshots
98+
99+
Base.iterate(tg::TemporalSnapshotsGNNGraph) = Base.iterate(tg.snapshots)
100+
Base.iterate(tg::TemporalSnapshotsGNNGraph, i) = Base.iterate(tg.snapshots, i)
95101

96102
function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int)
97103
tg.snapshots[t] = g

GNNGraphs/test/temporalsnapshotsgnngraph.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#TODO add graph_type = GRAPH_TYPE to all constructor calls
2+
13
@testset "Constructor array TemporalSnapshotsGNNGraph" begin
24
snapshots = [rand_graph(10, 20) for i in 1:5]
35
tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -12,6 +14,7 @@
1214
@test tg.num_snapshots == 5
1315
end
1416

17+
1518
@testset "==" begin
1619
snapshots = [rand_graph(10, 20) for i in 1:5]
1720
tsg1 = TemporalSnapshotsGNNGraph(snapshots)
@@ -41,7 +44,7 @@ end
4144
end
4245

4346
@testset "getproperty" begin
44-
x = rand(10)
47+
x = rand(Float32, 10)
4548
snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5]
4649
tsg = TemporalSnapshotsGNNGraph(snapshots)
4750
@test tsg.tgdata == DataStore()
@@ -111,18 +114,31 @@ end
111114
@testset "show" begin
112115
snapshots = [rand_graph(10, 20) for i in 1:5]
113116
tsg = TemporalSnapshotsGNNGraph(snapshots)
114-
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data"
115-
@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data"
117+
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)"
118+
@test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5)"
116119
@test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5"
117-
tsg.tgdata.x=rand(4)
118-
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
120+
tsg.tgdata.x = rand(Float32, 4)
121+
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)"
122+
end
123+
124+
@testset "broadcastable" begin
125+
snapshots = [rand_graph(10, 20) for i in 1:5]
126+
tsg = TemporalSnapshotsGNNGraph(snapshots)
127+
f(g) = g isa GNNGraph
128+
@test f.(tsg) == trues(5)
129+
end
130+
131+
@testset "iterate" begin
132+
snapshots = [rand_graph(10, 20) for i in 1:5]
133+
tsg = TemporalSnapshotsGNNGraph(snapshots)
134+
@test [g for g in tsg] isa Vector{<:GNNGraph}
119135
end
120136

121137
if TEST_GPU
122138
@testset "gpu" begin
123-
snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
139+
snapshots = [rand_graph(10, 20; ndata = rand(Float32, 5,10)) for i in 1:5]
124140
tsg = TemporalSnapshotsGNNGraph(snapshots)
125-
tsg.tgdata.x = rand(5)
141+
tsg.tgdata.x = rand(Float32, 5)
126142
dev = CUDADevice() #TODO replace with `gpu_device()`
127143
tsg = tsg |> dev
128144
@test tsg.snapshots[1].ndata.x isa CuArray

GNNLux/docs/src_tutorials/gnn_intro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ visualize_embeddings(emb_init, colors = labels)
220220
# If you are not new to Lux, this scheme should appear familiar to you.
221221

222222
# Note that our semi-supervised learning scenario is achieved by the following line:
223-
# ```
223+
# ```julia
224224
# logitcrossentropy(ŷ[:,train_mask], y[:,train_mask])
225225
# ```
226226
# While we compute node embeddings for all of our nodes, we **only make use of the training nodes for computing the loss**.

0 commit comments

Comments
 (0)