Commit 20c591c

don't automatically batch graphs inside vector when using getobs (#183)

* don't automatically batch graphs inside a vector when using `getobs`
* add tests
* bump version
* nice error

1 parent afd80e8 · commit 20c591c
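In short: indexing a `Vector{<:GNNGraph}` through `getobs` no longer batches the selected graphs implicitly; batching becomes an explicit step, or is delegated to the `DataLoader` via `collate=true`. A minimal before/after sketch (illustrative only, not part of the diff; assumes toy graphs from `rand_graph`):

```julia
using GraphNeuralNetworks, MLUtils
using Flux.Data: DataLoader

graphs = [rand_graph(10, 20) for _ in 1:32]   # toy Vector{GNNGraph}

# Before this commit, multi-observation indexing silently collated:
#   getobs(graphs, 1:4)  returned one batched GNNGraph with num_graphs == 4
# After this commit, it returns the plain sub-vector:
obs = getobs(graphs, 1:4)             # Vector{<:GNNGraph} of length 4

# Batching is now explicit ...
gbatch = MLUtils.batch(graphs[1:4])   # one GNNGraph, gbatch.num_graphs == 4

# ... or handled by the loader itself:
loader = DataLoader(graphs; batchsize=4, shuffle=true, collate=true)
first(loader)                         # a single batched GNNGraph per iteration
```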

File tree

10 files changed: +85 −61 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "0.4.5"
+version = "0.5.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/index.md

Lines changed: 10 additions & 6 deletions

@@ -26,7 +26,7 @@ Usage examples on real datasets can be found in the [examples](https://github.co
 We create a dataset consisting in multiple random graphs and associated data features.

 ```julia
-using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
+using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics, MLUtils
 using Flux.Data: DataLoader

 all_graphs = GNNGraph[]

@@ -60,13 +60,17 @@ opt = Adam(1f-4)
 ### Training

 Finally, we use a standard Flux training pipeline to fit our dataset.
-Flux's `DataLoader` iterates over mini-batches of graphs
-(batched together into a `GNNGraph` object).
+We use Flux's `DataLoader` to iterate over mini-batches of graphs
+that are glued together into a single `GNNGraph` using the [`MLUtils.batch`](@ref) method. This is what happens under the hood when creating a `DataLoader` with the
+`collate=true` option.

 ```julia
-train_size = round(Int, 0.8 * length(all_graphs))
-train_loader = DataLoader(all_graphs[1:train_size], batchsize=32, shuffle=true)
-test_loader = DataLoader(all_graphs[train_size+1:end], batchsize=32, shuffle=false)
+train_graphs, test_graphs = MLUtils.split(all_graphs, at=0.8)
+
+train_loader = DataLoader(train_graphs,
+                          batchsize=32, shuffle=true, collate=true)
+test_loader = DataLoader(test_graphs,
+                         batchsize=32, shuffle=false, collate=true)

 loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
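Since `collate=true` makes each iteration yield one pre-batched `GNNGraph`, the training loop in this document can consume the loader directly. A hedged sketch using only names the page itself defines (`model`, `opt`, `loss`, `train_loader`); the epoch count is arbitrary:

```julia
ps = Flux.params(model)
for epoch in 1:100
    for g in train_loader            # g is already one batched GNNGraph
        grad = gradient(() -> loss(g), ps)
        Flux.Optimise.update!(opt, ps, grad)
    end
end
```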

docs/src/tutorials/graph_classification_pluto.jl

Lines changed: 42 additions & 34 deletions

@@ -1,5 +1,5 @@
 ### A Pluto.jl notebook ###
-# v0.19.5
+# v0.19.6

 #> [frontmatter]
 #> title = "Graph Classification with Graph Neural Networks"

@@ -13,12 +13,13 @@ using InteractiveUtils
 begin
     using Pkg
     Pkg.activate(; temp=true)
-    packages = [
+    Pkg.add([
         PackageSpec(; path=joinpath(@__DIR__,"..","..","..")),
         PackageSpec(; name="Flux", version="0.13"),
         PackageSpec(; name="MLDatasets", version="0.7"),
-    ]
-    Pkg.add(packages)
+        PackageSpec(; name="MLUtils"),
+    ])
+    Pkg.develop("GraphNeuralNetworks")
 end

 # ╔═╡ 361e0948-d91a-11ec-2d95-2db77435a0c1
@@ -29,6 +30,7 @@ begin
     using Flux.Data: DataLoader
     using GraphNeuralNetworks
     using MLDatasets
+    using MLUtils
     using LinearAlgebra, Random, Statistics
     ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
     Random.seed!(17) # for reproducibility

@@ -73,8 +75,11 @@ This dataset provides **188 different graphs**, and the task is to classify each
 By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**.
 It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes).
 However, for the sake of simplicity, we will not make use of edge labels.
+"""

-We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:
+# ╔═╡ 7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb
+md"""
+We now convert the MLDatasets.jl graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict):
 """

 # ╔═╡ 936c09f6-ee62-4bc2-a0c6-749a66080fd2
@@ -84,19 +89,27 @@ begin
         ndata=Float32.(onehotbatch(g.ndata.targets, 0:6)),
         edata=nothing)
         for g in graphs]
+    y = onehotbatch(dataset.graph_data.targets, [-1, 1])
 end

+# ╔═╡ 2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4
+md"""We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:
+"""
+
 # ╔═╡ 519477b2-8323-4ece-a7eb-141e9841117c
+train_data, test_data = splitobs((graphs, y), at=150, shuffle=true) |> getobs
+
+# ╔═╡ 3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b
 begin
-    shuffled_idxs = randperm(length(graphs))
-    train_idxs = shuffled_idxs[1:150]
-    test_idxs = shuffled_idxs[151:end]
-    train_graphs = graphs[train_idxs]
-    test_graphs = graphs[test_idxs]
-    ytrain = onehotbatch(dataset.graph_data.targets[train_idxs], [-1, 1])
-    ytest = onehotbatch(dataset.graph_data.targets[test_idxs], [-1, 1])
+    train_loader = DataLoader(train_data, batchsize=64, shuffle=true)
+    test_loader = DataLoader(test_data, batchsize=64, shuffle=false)
 end

+# ╔═╡ f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4
+md"""
+Here, we opt for a `batch_size` of 64, leading to 3 (randomly shuffled) mini-batches, containing all ``2 \cdot 64+22 = 150`` graphs.
+"""
+
 # ╔═╡ 2a1c501e-811b-4ddd-887b-91e8c929c8b7
 md"""
 ## Mini-batching of graphs

@@ -114,35 +127,27 @@ This procedure has some crucial advantages over other batching procedures:

 2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.

-GNN.jl can **batch multiple graphs into a single giant graph** with the help of Flux's `DataLoader`:
+GNN.jl can **batch multiple graphs into a single giant graph**:
 """


-# ╔═╡ c202e3b7-1f39-496a-98e7-e03ada53b5c7
-begin
-    train_loader = DataLoader((train_graphs, ytrain), batchsize=64, shuffle=true)
-    test_loader = DataLoader((test_graphs, ytest), batchsize=64, shuffle=false)
-end
-
 # ╔═╡ a142610a-d862-42a9-88af-c8d8b6825650
-first(train_loader)
+vec_gs, _ = first(train_loader)

 # ╔═╡ 6faaf637-a0ff-468c-86b5-b0a7250258d6
-collect(train_loader)
-
-# ╔═╡ 6cc5e766-ddcd-4547-b69c-6435428caf44
-first(train_loader)[1]
+MLUtils.batch(vec_gs)

-# ╔═╡ ac69571a-998b-4630-afd6-f3d405618bc5
+# ╔═╡ e314b25f-e904-4c39-bf60-24cddf91fe9d
 md"""
-Here, we opt for a `batch_size` of 64, leading to 3 (randomly shuffled) mini-batches, containing all ``2 \cdot 64+22 = 150`` graphs.
-
-Furthermore, each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:
+Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:

 ```math
 \textrm{graph-indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ]
 ```
+"""

+# ╔═╡ ac69571a-998b-4630-afd6-f3d405618bc5
+md"""
 ## Training a Graph Neural Network (GNN)

 Training a GNN for graph classification usually follows a simple recipe:

@@ -186,7 +191,7 @@ function eval_loss_accuracy(model, data_loader, device)
     acc = 0.
     ntot = 0
     for (g, y) in data_loader
-        g, y = g |> device, y |> device
+        g, y = MLUtils.batch(g) |> device, y |> device
         n = length(y)
         ŷ = model(g, g.ndata.x)
         loss += logitcrossentropy(ŷ, y) * n

@@ -214,7 +219,7 @@ function train!(model; epochs=200, η=1e-2, infotime=10)
     report(0)
     for epoch in 1:epochs
         for (g, y) in train_loader
-            g, y = g |> device, y |> device
+            g, y = MLUtils.batch(g) |> device, y |> device
             gs = Flux.gradient(ps) do
                 ŷ = model(g, g.ndata.x)
                 logitcrossentropy(ŷ, y)

@@ -266,22 +271,25 @@ You have learned how graphs can be batched together for better GPU utilization,
 """

 # ╔═╡ Cell order:
-# ╟─c97a0002-2253-45b6-9266-017189dbb6fe
+# ╠═c97a0002-2253-45b6-9266-017189dbb6fe
 # ╠═361e0948-d91a-11ec-2d95-2db77435a0c1
 # ╟─15136fd8-f9b2-4841-9a95-9de7b8969687
 # ╠═f6e86958-e96f-4c77-91fc-c72d8967575c
 # ╠═24f76360-8599-46c8-a49f-4c31f02eb7d8
 # ╠═5d5e5152-c860-4158-8bc7-67ee1022f9f8
 # ╠═33163dd2-cb35-45c7-ae5b-d4854d141773
 # ╠═a8d6a133-a828-4d51-83c4-fb44f9d5ede1
-# ╟─3b3e0a79-264b-47d7-8bda-2a6db7290828
+# ╠═3b3e0a79-264b-47d7-8bda-2a6db7290828
+# ╠═7f7750ff-b7fa-4fe2-a5a8-6c9c26c479bb
 # ╠═936c09f6-ee62-4bc2-a0c6-749a66080fd2
+# ╟─2c6ccfdd-cf11-415b-b398-95e5b0b2bbd4
 # ╠═519477b2-8323-4ece-a7eb-141e9841117c
+# ╠═3c3d5038-0ef6-47d7-a1b7-50880c5f3a0b
+# ╟─f7778e2d-2e2a-4fc8-83b0-5242e4ec5eb4
 # ╟─2a1c501e-811b-4ddd-887b-91e8c929c8b7
-# ╠═c202e3b7-1f39-496a-98e7-e03ada53b5c7
 # ╠═a142610a-d862-42a9-88af-c8d8b6825650
 # ╠═6faaf637-a0ff-468c-86b5-b0a7250258d6
-# ╠═6cc5e766-ddcd-4547-b69c-6435428caf44
+# ╟─e314b25f-e904-4c39-bf60-24cddf91fe9d
 # ╟─ac69571a-998b-4630-afd6-f3d405618bc5
 # ╠═04402032-18a4-42b5-ad04-19b286bd29b7
 # ╟─2313fd8d-6e84-4bde-bacc-fb697dc33cbb
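For readers skimming the diff, the `graph_indicator` bookkeeping that the tutorial text describes is easy to see on a toy batch (an illustrative sketch, not part of the commit):

```julia
using GraphNeuralNetworks, MLUtils

gs = [rand_graph(4, 6), rand_graph(3, 4), rand_graph(5, 8)]
gbatch = MLUtils.batch(gs)

gbatch.num_graphs        # 3
gbatch.num_nodes         # 12 == 4 + 3 + 5
gbatch.graph_indicator   # [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]
```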

examples/Project.toml

Lines changed: 2 additions & 1 deletion

@@ -6,13 +6,14 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

 [compat]
 DiffEqFlux = "1.45"
 Flux = "0.13"
+GraphNeuralNetworks = "0.5"
 Graphs = "1"
-GraphNeuralNetworks = "0.4"
 MLDatasets = "0.6, 0.7"
 julia = "1.7"

examples/graph_classification_tudataset.jl

Lines changed: 7 additions & 9 deletions

@@ -15,9 +15,8 @@ function eval_loss_accuracy(model, data_loader, device)
     loss = 0.
     acc = 0.
     ntot = 0
-    for (graphs, y) in data_loader
-        g = Flux.batch(graphs) |> device
-        y = y |> device
+    for (g, y) in data_loader
+        g, y = (g, y) |> device
         n = length(y)
         ŷ = model(g, g.ndata.x) |> vec
         loss += logitbinarycrossentropy(ŷ, y) * n

@@ -66,10 +65,10 @@ function train(; kws...)
     NUM_TRAIN = 150

     dataset = getdataset()
-    train_data, test_data = splitobs(dataset, at=NUM_TRAIN/numobs(dataset), shuffle=true)
+    train_data, test_data = splitobs(dataset, at=NUM_TRAIN, shuffle=true)

-    train_loader = DataLoader(train_data, batchsize=args.batchsize, shuffle=true)
-    test_loader = DataLoader(test_data, batchsize=args.batchsize, shuffle=false)
+    train_loader = DataLoader(train_data; args.batchsize, shuffle=true, collate=true)
+    test_loader = DataLoader(test_data; args.batchsize, shuffle=false, collate=true)

     # DEFINE MODEL

@@ -96,9 +95,8 @@ function train(; kws...)

     report(0)
     for epoch in 1:args.epochs
-        for (graphs, y) in train_loader
-            g = Flux.batch(graphs) |> device
-            y = y |> device
+        for (g, y) in train_loader
+            g, y = (g, y) |> device
             gs = Flux.gradient(ps) do
                 ŷ = model(g, g.ndata.x) |> vec
                 logitbinarycrossentropy(ŷ, y)
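The simplification in this script follows from `collate=true`: the loader now yields an already-collated `(g, y)` tuple, so a single `|> device` moves both pieces at once (Flux's `gpu`/`cpu` map over tuples). Illustrative, using the script's own `train_loader`:

```julia
g, y = first(train_loader)    # g::GNNGraph (already batched), y labels
g, y = (g, y) |> Flux.gpu     # one call moves graph and labels together
```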

src/GNNGraphs/gnngraph.jl

Lines changed: 0 additions & 6 deletions

@@ -236,12 +236,6 @@
 MLUtils.numobs(g::GNNGraph) = g.num_graphs
 MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i)

-# DataLoader compatibility passing a vector of graphs and
-# effectively using `batch` as a collated function.
-MLUtils.numobs(data::Vector{<:GNNGraph}) = length(data)
-MLUtils.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
-MLUtils.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
-

 #########################
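With these specialized methods deleted, a `Vector{<:GNNGraph}` falls back to MLUtils' generic array handling, which indexes without batching. Roughly (a paraphrase of the fallback behavior, not the exact MLUtils source):

```julia
# numobs(data::AbstractVector)     -> length(data)
# getobs(data::AbstractVector, i)  -> data[i]   (a sub-vector; no batching)
```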

src/GNNGraphs/transform.jl

Lines changed: 3 additions & 0 deletions

@@ -432,6 +432,9 @@ function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where T<:COO_T
     )
 end

+Flux.batch(g::GNNGraph) =
+    throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). Pass a vector of `GNNGraph`s instead."))
+
 """
     unbatch(g::GNNGraph)
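The new method turns a likely user mistake, calling `batch` on an already-batched graph, into a clear error. What a user would now see (sketch based on the error string added above):

```julia
g = MLUtils.batch([rand_graph(4, 6), rand_graph(5, 8)])
Flux.batch(g)
# ERROR: ArgumentError: Cannot batch a `GNNGraph` (containing 2 graphs).
# Pass a vector of `GNNGraph`s instead.
```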

src/layers/basic.jl

Lines changed: 6 additions & 0 deletions

@@ -11,6 +11,12 @@ abstract type GNNLayer end
 # To be specialized by layers also needing edge features as input (e.g. NNConv).
 (l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))

+function (l::GNNLayer)(g::AbstractVector{<:GNNGraph}, args...; kws...)
+    @warn "Passing an array of graphs to a `GNNLayer` is discouraged.
+           Explicitely call `MLUtils.batch(graphs)` first instead." maxlog=1
+    return l(batch(g), args...; kws...)
+end
+

 """
     WithGraph(model, g::GNNGraph; traingraph=false)
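This fallback keeps vector inputs working (with a one-time warning) while steering users toward the explicit call. A usage sketch mirroring the new test added in test/layers/basic.jl below:

```julia
gs = [rand_graph(5, 6, ndata=rand(Float32, 2, 5)) for _ in 1:4]
l = GCNConv(2 => 3)
x = rand(Float32, 2, 20)   # 4 graphs × 5 nodes = 20 feature columns

y1 = l(gs, x)                   # warns once, batches internally
y2 = l(MLUtils.batch(gs), x)    # preferred explicit form; y1 == y2
# both of size (3, 20)
```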

test/GNNGraphs/gnngraph.jl

Lines changed: 5 additions & 4 deletions

@@ -269,13 +269,14 @@
         @test first(d) == getgraph(g, 1:2)
     end

-    @testset "pass to dataloader and collate" begin
-        @test MLUtils.getobs(data, 3) == getgraph(g, 3)
-        @test MLUtils.getobs(data, 3:5) == getgraph(g, 3:5)
+    @testset "pass to dataloader and no automatic collation" begin
+        @test MLUtils.getobs(data, 3) == data[3]
+        @test MLUtils.getobs(data, 3:5) isa Vector{<:GNNGraph}
+        @test MLUtils.getobs(data, 3:5) == [data[3], data[4], data[5]]
         @test MLUtils.numobs(data) == g.num_graphs

         d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
-        @test first(d) == getgraph(g, 1:2)
+        @test first(d) == [data[1], data[2]]
     end
 end

test/layers/basic.jl

Lines changed: 9 additions & 0 deletions

@@ -93,5 +93,14 @@
         params, restructure = Flux.destructure(chain)
         @test restructure(params) isa GNNChain
     end
+    @testset "GNNGraph array input" begin
+        gs = [rand_graph(5, 6, ndata=rand(2, 5), graph_type=GRAPH_T) for _ in 1:4]
+        l = GCNConv(2 => 3)
+        y = l(gs, rand(2, 20))
+        @test size(y) == (3, 20)
+
+        gout = l(gs)
+        @test size(gout.ndata.x) == (3, 20)
+    end
 end
