Commit 05daca6

dataloader support for vector of graphs (#143)
1 parent f935e8d commit 05daca6

File tree

8 files changed

+127
-84
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,3 +4,4 @@
44
Manifest.toml
55
/docs/build/
66
.vscode
7+
LocalPreferences.toml

docs/src/gnngraph.md

Lines changed: 30 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -144,38 +144,62 @@ julia> get_edge_weight(g)
144144
## Batches and Subgraphs
145145

146146
Multiple `GNNGraph`s can be batched together into a single graph
147-
containing the total number of the original nodes
147+
that contains the total number of the original nodes
148148
and where the original graphs are disjoint subgraphs.
149149

150150
```julia
151151
using Flux
152+
using Flux.Data: DataLoader
152153

153-
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])
154+
data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:160]
155+
gall = Flux.batch(data)
154156

157+
# gall is a GNNGraph containing many graphs
155158
@assert gall.num_graphs == 160
156159
@assert gall.num_nodes == 1600 # 10 nodes x 160 graphs
157160
@assert gall.num_edges == 9600 # 30 undirected edges x 2 directions x 160 graphs
158161

162+
# Let's create a mini-batch from gall
159163
g23, _ = getgraph(gall, 2:3)
160164
@assert g23.num_graphs == 2
161165
@assert g23.num_nodes == 20 # 10 nodes x 2 graphs
162166
@assert g23.num_edges == 120 # 30 undirected edges x 2 directions x 2 graphs
163167

164-
165-
# DataLoader compatibility
166-
train_loader = Flux.Data.DataLoader(gall, batchsize=16, shuffle=true)
168+
# We can pass a GNNGraph to Flux's DataLoader
169+
train_loader = DataLoader(gall, batchsize=16, shuffle=true)
167170

168171
for g in train_loader
169172
@assert g.num_graphs == 16
170173
@assert g.num_nodes == 160
171174
@assert size(g.ndata.x) == (3, 160)
172-
.....
175+
# .....
173176
end
174177

175178
# Access the nodes' graph memberships
176179
graph_indicator(gall)
177180
```
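
The batching described above (one large graph whose components are the original graphs) can be sketched in plain Julia. This is only an illustration under stated assumptions: `ToyGraph` and `toy_batch` are hypothetical names, not part of GraphNeuralNetworks.jl, and the package's `Flux.batch` method for `GNNGraph` may be implemented differently. The sketch shifts the node indices of each graph by the number of nodes seen so far and records each node's graph membership.

```julia
# Hypothetical toy graph in COO (edge list) form; NOT the package's GNNGraph.
struct ToyGraph
    src::Vector{Int}              # source node of each edge
    dst::Vector{Int}              # target node of each edge
    num_nodes::Int
    graph_indicator::Vector{Int}  # graph membership of each node
end

# Convenience constructor for a single (unbatched) graph.
ToyGraph(src, dst, n) = ToyGraph(src, dst, n, ones(Int, n))

# Hypothetical helper: disjoint union of single graphs.
function toy_batch(graphs::Vector{ToyGraph})
    src, dst, indicator = Int[], Int[], Int[]
    offset = 0
    for (i, g) in enumerate(graphs)
        append!(src, g.src .+ offset)            # shift node ids of graph i
        append!(dst, g.dst .+ offset)
        append!(indicator, fill(i, g.num_nodes)) # nodes of graph i get label i
        offset += g.num_nodes
    end
    return ToyGraph(src, dst, offset, indicator)
end

# A directed 3-cycle, batched with a copy of itself.
triangle = ToyGraph([1, 2, 3], [2, 3, 1], 3)
gall = toy_batch([triangle, triangle])

@assert gall.num_nodes == 6
@assert gall.src == [1, 2, 3, 4, 5, 6]
@assert gall.dst == [2, 3, 1, 5, 6, 4]
@assert gall.graph_indicator == [1, 1, 1, 2, 2, 2]
```

The node and edge counts of the batched graph are simply the sums over the inputs, which is the arithmetic behind the `1600` and `9600` assertions in the example above.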
178181

182+
## DataLoader and mini-batch iteration
183+
184+
While constructing a batched graph and passing it to the `DataLoader` is always
185+
an option for mini-batch iteration, the recommended way is
186+
to pass an array of graphs directly:
187+
188+
```julia
189+
using Flux.Data: DataLoader
190+
191+
data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:320]
192+
193+
train_loader = DataLoader(data, batchsize=16, shuffle=true)
194+
195+
for g in train_loader
196+
@assert g.num_graphs == 16
197+
@assert g.num_nodes == 160
198+
@assert size(g.ndata.x) == (3, 160)
199+
# .....
200+
end
201+
```
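
Passing a vector of graphs amounts to slicing the vector into chunks and collating each chunk into one batched object. The following plain-Julia sketch illustrates that idea with stand-in data. `minibatches` and `collate` are hypothetical names introduced here, the items are named tuples rather than real `GNNGraph`s, and this is not how Flux's `DataLoader` is implemented internally.

```julia
# Hypothetical helper: lazily yields collate(chunk) for consecutive chunks of `data`.
minibatches(collate, data::AbstractVector; batchsize::Int) =
    (collate(collect(chunk)) for chunk in Iterators.partition(data, batchsize))

# Stand-in "graphs": each item only records its node count.
data = [(num_nodes = 10,) for _ in 1:320]

# Stand-in collate function: merge a chunk by counting items and summing node counts.
collate(items) = (num_graphs = length(items),
                  num_nodes  = sum(x -> x.num_nodes, items))

batches = collect(minibatches(collate, data; batchsize = 16))

@assert length(batches) == 20   # 320 items / 16 per batch
@assert all(b -> b.num_graphs == 16 && b.num_nodes == 160, batches)
```

Collating per mini-batch means the full dataset never has to be merged into one large graph up front; only the graphs in the current chunk are combined.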
202+
179203
## Graph Manipulation
180204

181205
```julia

docs/src/index.md

Lines changed: 26 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -23,64 +23,50 @@ Usage examples on real datasets can be found in the [examples](https://github.co
2323

2424
### Data preparation
2525

26-
First, we create our dataset consisting in multiple random graphs and associated data features.
27-
Then we batch the graphs together into a unique graph.
26+
We create a dataset consisting of multiple random graphs and associated data features.
2827

2928
```julia
30-
julia> using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
31-
32-
julia> all_graphs = GNNGraph[];
33-
34-
julia> for _ in 1:1000
35-
g = GNNGraph(random_regular_graph(10, 4),
36-
ndata=(; x = randn(Float32, 16,10)), # input node features
37-
gdata=(; y = randn(Float32))) # regression target
38-
push!(all_graphs, g)
39-
end
40-
41-
julia> gbatch = Flux.batch(all_graphs)
42-
GNNGraph:
43-
num_nodes = 10000
44-
num_edges = 40000
45-
num_graphs = 1000
46-
ndata:
47-
x => (16, 10000)
48-
gdata:
49-
y => (1000,)
50-
```
29+
using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
30+
using Flux.Data: DataLoader
5131

32+
all_graphs = GNNGraph[]
33+
34+
for _ in 1:1000
35+
g = GNNGraph(random_regular_graph(10, 4),
36+
ndata=(; x = randn(Float32, 16,10)), # input node features
37+
gdata=(; y = randn(Float32))) # regression target
38+
push!(all_graphs, g)
39+
end
40+
```
5241

5342
### Model building
5443

55-
We concisely define our model as a [`GNNChain`](@ref) containing 2 graph convolutional
56-
layers. If CUDA is available, our model will live on the gpu.
44+
We concisely define our model as a [`GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, our model will live on the GPU.
5745

5846
```julia
59-
julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;
60-
61-
julia> model = GNNChain(GCNConv(16 => 64),
62-
BatchNorm(64), # Apply batch normalization on node features (nodes dimension is batch dimension)
63-
x -> relu.(x),
64-
GCNConv(64 => 64, relu),
65-
GlobalPool(mean), # aggregate node-wise features into graph-wise features
66-
Dense(64, 1)) |> device;
47+
device = CUDA.functional() ? Flux.gpu : Flux.cpu;
6748

68-
julia> ps = Flux.params(model);
49+
model = GNNChain(GCNConv(16 => 64),
50+
BatchNorm(64), # Apply batch normalization on node features (nodes dimension is batch dimension)
51+
x -> relu.(x),
52+
GCNConv(64 => 64, relu),
53+
GlobalPool(mean), # aggregate node-wise features into graph-wise features
54+
Dense(64, 1)) |> device
6955

70-
julia> opt = ADAM(1f-4);
56+
ps = Flux.params(model)
57+
opt = ADAM(1f-4)
7158
```
7259

7360
### Training
7461

7562
Finally, we use a standard Flux training pipeline to fit our dataset.
76-
Flux's DataLoader iterates over mini-batches of graphs
63+
Flux's `DataLoader` iterates over mini-batches of graphs
7764
(batched together into a `GNNGraph` object).
7865

7966
```julia
80-
gtrain = getgraph(gbatch, 1:800)
81-
gtest = getgraph(gbatch, 801:gbatch.num_graphs)
82-
train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
83-
test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)
67+
train_size = round(Int, 0.8 * length(all_graphs))
68+
train_loader = DataLoader(all_graphs[1:train_size], batchsize=32, shuffle=true)
69+
test_loader = DataLoader(all_graphs[train_size+1:end], batchsize=32, shuffle=false)
8470

8571
loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
8672

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,3 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1010
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1111
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
13+
14+
[extras]
15+
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"

examples/graph_classification_tudataset.jl

Lines changed: 17 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -27,18 +27,21 @@ function eval_loss_accuracy(model, data_loader, device)
2727
end
2828

2929
function getdataset()
30-
data = TUDataset("MUTAG")
30+
tudata = TUDataset("MUTAG")
3131

32-
x = Array{Float32}(onehotbatch(data.node_labels, 0:6))
33-
y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
32+
x = Array{Float32}(onehotbatch(tudata.node_labels, 0:6))
33+
y = (1 .+ Array{Float32}(tudata.graph_labels)) ./ 2
3434
@assert all(∈([0,1]), y) # binary classification
35-
# The dataset also has edge features but we won't be using them
36-
e = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
3735

38-
return GNNGraph(data.source, data.target,
39-
num_nodes=data.num_nodes,
40-
graph_indicator=data.graph_indicator,
41-
ndata=(; x), edata=(; e), gdata=(; y))
36+
## The dataset also has edge features but we won't be using them
37+
# e = Array{Float32}(onehotbatch(tudata.edge_labels, sort(unique(tudata.edge_labels))))
38+
39+
gall = GNNGraph(tudata.source, tudata.target,
40+
num_nodes=tudata.num_nodes,
41+
graph_indicator=tudata.graph_indicator,
42+
ndata=(; x), gdata=(; y))
43+
44+
return [getgraph(gall, i) for i=1:gall.num_graphs]
4245
end
4346

4447
# arguments for the `train` function
@@ -66,23 +69,17 @@ function train(; kws...)
6669
end
6770

6871
# LOAD DATA
69-
70-
7172
NUM_TRAIN = 150
7273

73-
gfull = getdataset()
74-
75-
@info gfull
74+
data = getdataset()
75+
shuffle!(data)
7676

77-
perm = randperm(gfull.num_graphs)
78-
gtrain = getgraph(gfull, perm[1:NUM_TRAIN])
79-
gtest = getgraph(gfull, perm[NUM_TRAIN+1:end])
80-
train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
81-
test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)
77+
train_loader = DataLoader(data[1:NUM_TRAIN], batchsize=args.batchsize, shuffle=true)
78+
test_loader = DataLoader(data[NUM_TRAIN+1:end], batchsize=args.batchsize, shuffle=false)
8279

8380
# DEFINE MODEL
8481

85-
nin = size(gtrain.ndata.x, 1)
82+
nin = size(data[1].ndata.x, 1)
8683
nhidden = args.nhidden
8784

8885
model = GNNChain(GraphConv(nin => nhidden, relu),
@@ -94,7 +91,6 @@ function train(; kws...)
9491
ps = Flux.params(model)
9592
opt = ADAM(args.η)
9693

97-
9894
# LOGGING FUNCTION
9995

10096
function report(epoch)

src/GNNGraphs/gnngraph.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,27 @@ LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
231231
Flux.Data._nobs(g::GNNGraph) = g.num_graphs
232232
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
233233

234+
# DataLoader compatibility passing a vector of graphs and
235+
# effectively using `batch` as a collate function.
236+
StatsBase.nobs(data::Vector{<:GNNGraph}) = length(data)
237+
LearnBase.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
238+
LearnBase.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
239+
Flux.Data._nobs(g::Vector{<:GNNGraph}) = StatsBase.nobs(g)
240+
Flux.Data._getobs(g::Vector{<:GNNGraph}, i) = LearnBase.getobs(g, i)
241+
242+
234243
#########################
235244

236245
function Base.:(==)(g1::GNNGraph, g2::GNNGraph)
237246
g1 === g2 && return true
238-
all(k -> getfield(g1, k) == getfield(g2, k), fieldnames(typeof(g1)))
247+
for k in fieldnames(typeof(g1))
248+
k === :graph_indicator && continue
249+
getfield(g1, k) != getfield(g2, k) && return false
250+
end
251+
return true
239252
end
240253

241254
function Base.hash(g::T, h::UInt) where T<:GNNGraph
242-
fs = (getfield(g, k) for k in fieldnames(typeof(g)))
255+
fs = (getfield(g, k) for k in fieldnames(typeof(g)) if k !== :graph_indicator)
243256
return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
244257
end

test/GNNGraphs/gnngraph.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,22 +229,34 @@
229229
# Attach non array data
230230
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
231231
@test g.edata.e == "ciao"
232-
end
232+
end
233233

234234
@testset "LearnBase and DataLoader compat" begin
235235
n, m, num_graphs = 10, 30, 50
236236
X = rand(10, n)
237-
E = rand(10, 2m)
237+
E = rand(10, m)
238238
U = rand(10, 1)
239-
g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
240-
for _ in 1:num_graphs])
241-
242-
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
243-
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
244-
@test StatsBase.nobs(g) == g.num_graphs
245-
246-
d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
247-
@test first(d) == getgraph(g, 1:2)
239+
data = [rand_graph(n, m, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
240+
for _ in 1:num_graphs]
241+
g = Flux.batch(data)
242+
243+
@testset "batch then pass to dataloader" begin
244+
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
245+
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
246+
@test StatsBase.nobs(g) == g.num_graphs
247+
248+
d = Flux.Data.DataLoader(g, batchsize=2, shuffle=false)
249+
@test first(d) == getgraph(g, 1:2)
250+
end
251+
252+
@testset "pass to dataloader and collate" begin
253+
@test LearnBase.getobs(data, 3) == getgraph(g, 3)
254+
@test LearnBase.getobs(data, 3:5) == getgraph(g, 3:5)
255+
@test StatsBase.nobs(data) == g.num_graphs
256+
257+
d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
258+
@test first(d) == getgraph(g, 1:2)
259+
end
248260
end
249261

250262
@testset "Graphs.jl integration" begin

test/layers/conv.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,12 @@
103103
end
104104

105105
@testset "GATConv" begin
106-
106+
107107
for heads in (1, 2), concat in (true, false)
108108
l = GATConv(in_channel => out_channel; heads, concat)
109109
for g in test_graphs
110110
test_layer(l, g, rtol=RTOL_LOW,
111+
exclude_grad_fields = [:negative_slope],
111112
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
112113
end
113114
end
@@ -116,7 +117,9 @@
116117
ein = 3
117118
l = GATConv((in_channel, ein) => out_channel, add_self_loops=false)
118119
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
119-
test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
120+
test_layer(l, g, rtol=RTOL_LOW,
121+
exclude_grad_fields = [:negative_slope],
122+
outsize=(out_channel, g.num_nodes))
120123
end
121124

122125
@testset "num params" begin
@@ -135,6 +138,7 @@
135138
l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
136139
for g in test_graphs
137140
test_layer(l, g, rtol=RTOL_LOW,
141+
exclude_grad_fields = [:negative_slope],
138142
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
139143
end
140144
end
@@ -143,7 +147,9 @@
143147
ein = 3
144148
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
145149
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
146-
test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
150+
test_layer(l, g, rtol=RTOL_LOW,
151+
exclude_grad_fields = [:negative_slope],
152+
outsize=(out_channel, g.num_nodes))
147153
end
148154

149155
@testset "num params" begin
@@ -159,7 +165,9 @@
159165
ein = 3
160166
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
161167
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
162-
test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
168+
test_layer(l, g, rtol=1e-3,
169+
exclude_grad_fields = [:negative_slope],
170+
outsize=(out_channel, g.num_nodes))
163171
end
164172
end
165173
