
Commit 1687abe

Add Graph Classification tutorial (#568)
* Add GlobalPool
* Add graph classification tutorial
* Add `GlobalPool` pooling docs
* Fix ref
* Add `GlobalPool` test
* Fix
* Fix text
* Add changes to the src file
* Fix pooling layer

---------

Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 4bb9b96 commit 1687abe

8 files changed (+588, -1 lines)

GNNLux/docs/make.jl

Lines changed: 1 addition & 0 deletions
````diff
@@ -62,6 +62,7 @@ makedocs(;
         "Introductory tutorials" => [
             "Hands on" => "tutorials/gnn_intro.md",
             "Node Classification" => "tutorials/node_classification.md",
+            "Graph Classification" => "tutorials/graph_classification.md",
         ],
     ],
````

GNNLux/docs/make_tutorials.jl

Lines changed: 3 additions & 1 deletion
````diff
@@ -2,4 +2,6 @@ using Literate
 
 Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true)
 
-Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true)
+Literate.markdown("src_tutorials/graph_classification.jl", "src/tutorials/"; execute = true)
+
+Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true)
````

GNNLux/docs/src/api/pool.md

Lines changed: 19 additions & 0 deletions
````diff
@@ -0,0 +1,19 @@
+```@meta
+CurrentModule = GNNLux
+CollapsedDocStrings = true
+```
+
+# Pooling Layers
+
+## Index
+
+```@index
+Order = [:type, :function]
+Pages = ["pool.md"]
+```
+
+```@autodocs
+Modules = [GNNLux]
+Pages = ["layers/pool.jl"]
+Private = false
+```
````
Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@

# Graph Classification with Graph Neural Networks

*This tutorial is a Julia adaptation of the PyTorch Geometric tutorial that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).*

In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**.
Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties and possibly on some input node features.
Here, we want to embed entire graphs, and we want to embed those graphs in such a way that they are linearly separable for the task at hand.

A common graph classification task is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not.

TU Dortmund University has collected a wide range of graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

````julia
using Lux, GNNLux
using MLDatasets, MLUtils
using LinearAlgebra, Random, Statistics
using Zygote, Optimisers, OneHotArrays

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
rng = Random.seed!(42); # for reproducibility

dataset = TUDataset("MUTAG")
````

````
dataset TUDataset:
  name        =>    MUTAG
  metadata    =>    Dict{String, Any} with 1 entry
  graphs      =>    188-element Vector{MLDatasets.Graph}
  graph_data  =>    (targets = "188-element Vector{Int64}",)
  num_nodes   =>    3371
  num_edges   =>    7442
  num_graphs  =>    188
````

````julia
dataset.graph_data.targets |> union
````

````
2-element Vector{Int64}:
  1
 -1
````

````julia
g1, y1 = dataset[1] # get the first graph and target
````

````
(graphs = Graph(17, 38), targets = 1)
````

````julia
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union
````

````
7-element Vector{Int64}:
 0
 1
 2
 3
 4
 5
 6
````

````julia
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union
````

````
4-element Vector{Int64}:
 0
 1
 2
 3
````

This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**.

By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**.
It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes).
However, for the sake of simplicity, we will not make use of edge labels.

We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s and one-hot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict):

````julia
graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                   ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                   edata = nothing)
          for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])
````

````
2×188 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ 1 1 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ 1 1 1 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ 1
 1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ 1 1 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 1 1 1 ⋅ 1 1 1 1 1 1 ⋅ 1 1 ⋅ ⋅ 1 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ ⋅ ⋅ 1 1 1 1 1 ⋅ 1 1 1 ⋅ ⋅ 1 1 1 1 1 1 1 1 ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 1 ⋅ ⋅ ⋅ 1 1 1 ⋅ 1 1 1 1 1 1 1 ⋅ 1 1 1 1 1 1 ⋅ 1 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 ⋅
````

We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:

````julia
train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs

train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)
````

````
2-element DataLoader(::Tuple{Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (32-element Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)
````

Here, we opt for a `batchsize` of 32, leading to 5 (randomly shuffled) mini-batches containing all $4 \cdot 32 + 22 = 150$ training graphs.
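
As a quick sanity check, here is a small sketch using the loaders defined above (`numobs` comes from MLUtils, which we already imported):

````julia
length(train_loader)                       # 5 mini-batches over the training set
sum(numobs(y) for (_, y) in train_loader)  # 150 training graphs in total
length(test_loader)                        # 2 mini-batches for the 38 test graphs
````
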
## Mini-batching of graphs

Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization.
In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension.
The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`.

However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.
Therefore, GNNLux.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension).

This procedure has some crucial advantages over other batching procedures:

1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.

GNNLux.jl can **batch multiple graphs into a single giant graph**:

````julia
vec_gs, _ = first(train_loader)
````

````
(GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(11, 22) with x: 7×11 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(28, 66) with x: 7×28 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(27, 66) with x: 7×27 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(16, 36) with x: 7×16 data], Bool[1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0; 0 1 1 1 1 1 1 1 0 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1])
````

````julia
MLUtils.batch(vec_gs)
````

````
GNNGraph:
  num_nodes: 570
  num_edges: 1254
  num_graphs: 32
  ndata:
    x = 7×570 Matrix{Float32}
````

Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:

```math
\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots]
```
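
We can inspect this vector directly on a batched graph. The following is a small sketch; `graph_indicator` is a field of the batched `GNNGraph`, and the exact node counts depend on the shuffled mini-batch above:

````julia
g = MLUtils.batch(vec_gs)
g.graph_indicator                                 # 570-element vector with entries in 1:32
[count(==(i), g.graph_indicator) for i in 1:3]    # number of nodes in each of the first three graphs
````
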
## Training a Graph Neural Network (GNN)

Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

There exist multiple **readout layers** in the literature, but the most common one is simply to take the average of the node embeddings:

```math
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
```

GNNLux.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`.
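
To see the readout in isolation, here is a rough sketch that applies `GlobalPool(mean)` directly to the batched graph from above, assuming it follows the same `(g, x, ps, st)` calling convention as the other GNNLux layers:

````julia
pool = GlobalPool(mean)
ps_pool = LuxCore.initialparameters(rng, pool)  # the pooling layer has no trainable parameters
st_pool = LuxCore.initialstates(rng, pool)
g = MLUtils.batch(vec_gs)                       # 32 graphs, 570 nodes in total
emb, _ = pool(g, g.ndata.x, ps_pool, st_pool)
size(emb)                                       # (7, 32): one embedding column per graph
````
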
182+
183+
The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:
184+
185+
````julia
186+
function create_model(nin, nh, nout)
187+
GNNChain(GCNConv(nin => nh, relu),
188+
GCNConv(nh => nh, relu),
189+
GCNConv(nh => nh),
190+
GlobalPool(mean),
191+
Dropout(0.5),
192+
Dense(nh, nout))
193+
end;
194+
195+
nin = 7
196+
nh = 64
197+
nout = 2
198+
model = create_model(nin, nh, nout)
199+
200+
ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);
201+
````
202+
203+
````
204+
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
205+
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
206+
207+
````
208+
209+
Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.
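
Before training, a quick shape check (a sketch reusing the mini-batch `vec_gs` from above) confirms that the untrained model already maps a batched graph to one score per class and per graph:

````julia
g = MLUtils.batch(vec_gs)
ŷ, _ = model(g, g.ndata.x, ps, st)
size(ŷ)   # (2, 32): nout = 2 class scores for each of the 32 graphs in the batch
````
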
Let's train our network for a few epochs to see how well it performs on the training as well as the test set:

````julia
function custom_loss(model, ps, st, tuple)
    g, x, y = tuple
    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
    st = Lux.trainmode(st)
    ŷ, st = model(g, x, ps, st)
    return logitcrossentropy(ŷ, y), (; layers = st), 0
end

function eval_loss_accuracy(model, ps, st, data_loader)
    loss = 0.0
    acc = 0.0
    ntot = 0
    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
    for (g, y) in data_loader
        g = MLUtils.batch(g)
        n = length(y)
        ŷ, _ = model(g, g.ndata.x, ps, st)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end

function train_model!(model, ps, st; epochs = 500, infotime = 100)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2))

    function report(epoch)
        train = eval_loss_accuracy(model, ps, st, train_loader)
        st = Lux.testmode(st)
        test = eval_loss_accuracy(model, ps, st, test_loader)
        st = Lux.trainmode(st)
        @info (; epoch, train, test)
    end
    report(0)
    for iter in 1:epochs
        for (g, y) in train_loader
            g = MLUtils.batch(g)
            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.ndata.x, y), train_state)
        end

        iter % infotime == 0 && report(iter)
    end
    return model, ps, st
end

model, ps, st = train_model!(model, ps, st);
````

````
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils ~/.julia/packages/LuxLib/ru5RQ/src/utils.jl:314
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
[ Info: (epoch = 0, train = (loss = 0.6934, acc = 51.67), test = (loss = 0.6902, acc = 50.0))
[ Info: (epoch = 100, train = (loss = 0.3979, acc = 81.33), test = (loss = 0.5769, acc = 69.74))
[ Info: (epoch = 200, train = (loss = 0.3904, acc = 84.0), test = (loss = 0.6402, acc = 65.79))
[ Info: (epoch = 300, train = (loss = 0.3813, acc = 85.33), test = (loss = 0.6331, acc = 69.74))
[ Info: (epoch = 400, train = (loss = 0.3682, acc = 85.0), test = (loss = 0.7273, acc = 69.74))
[ Info: (epoch = 500, train = (loss = 0.3561, acc = 86.67), test = (loss = 0.6825, acc = 73.68))

````

As one can see, our model reaches around **74% test accuracy**.
The fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs) and usually disappear once GNNs are applied to larger datasets.

## (Optional) Exercise

Can we do better than this?
As multiple papers have pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**.
An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information:

```math
\mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)}
```

This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.

As an exercise, you are invited to adapt the model above so that it makes use of `GraphConv` rather than `GCNConv` (one possible starting point is sketched below).
This should bring you close to **82% test accuracy**.
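
The sketch below is one possible starting point for the exercise, not the reference solution. It assumes `GraphConv` is available with the same `(in => out, σ)` constructor as `GCNConv`; `create_graphconv_model` is a hypothetical helper name:

````julia
function create_graphconv_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),   # no neighborhood normalization, with skip-connection
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;

# Reuse the training loop defined above:
# model = create_graphconv_model(7, 64, 2)
# ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model)
# model, ps, st = train_model!(model, ps, st)
````
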
## Conclusion

In this chapter, you have learned how to apply GNNs to the task of graph classification.
You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
