
Commit d7e036f

Add graph classification literate and some fixes
1 parent 2f5a00a commit d7e036f

File tree

8 files changed: +521 -2508 lines changed


GraphNeuralNetworks/docs/make.jl

Lines changed: 3 additions & 3 deletions
@@ -60,9 +60,9 @@ makedocs(;
         "Tutorials" => [
             "Introductory tutorials" => [
-                "Hands on" => "tutorials/gnn_intro_pluto.md",
-                "Node classification" => "tutorials/node_classification_pluto.md",
-                "Graph classification" => "tutorials/graph_classification_pluto.md"
+                "Hands on" => "tutorials/gnn_intro.md",
+                "Node classification" => "tutorials/node_classification.md",
+                "Graph classification" => "tutorials/graph_classification.md"
             ],
             "Temporal graph neural networks" =>[
                 "Node autoregression" => "tutorials/traffic_prediction.md",

GraphNeuralNetworks/docs/make_tutorials_literate.jl

Lines changed: 3 additions & 0 deletions
@@ -8,5 +8,8 @@ Literate.markdown("src_tutorials/introductory_tutorials/gnn_intro.jl",
 Literate.markdown("src_tutorials/introductory_tutorials/node_classification.jl",
                   "src/tutorials/"; execute = true)
 
+Literate.markdown("src_tutorials/introductory_tutorials/graph_classification.jl",
+                  "src/tutorials/"; execute = true)
+
 Literate.markdown("src_tutorials/introductory_tutorials/temporal_graph_classification.jl",
                   "src/tutorials/"; execute = true)

GraphNeuralNetworks/docs/src/tutorials/gnn_intro_pluto.md

Lines changed: 0 additions & 261 deletions
This file was deleted.

GraphNeuralNetworks/docs/src/tutorials/graph_classification.md

Lines changed: 310 additions & 0 deletions
Large diffs are not rendered by default.

GraphNeuralNetworks/docs/src/tutorials/graph_classification_pluto.md

Lines changed: 0 additions & 224 deletions
This file was deleted.

GraphNeuralNetworks/docs/src/tutorials/node_classification_pluto.md

Lines changed: 0 additions & 337 deletions
This file was deleted.
GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/graph_classification.jl

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
# # Graph Classification with Graph Neural Networks

# *This tutorial is a Julia adaptation of the PyTorch Geometric tutorials that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).*

# In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**.
# Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties and possibly on some input node features.
# Here, we want to embed entire graphs, and we want to embed those graphs in such a way that they are linearly separable given the task at hand.
# We will use a graph convolutional network to create a vector embedding of the input graph, and then apply a simple linear classification head to perform the final classification.

# A common graph classification task is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not.

# The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
# Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

using Flux, GraphNeuralNetworks
using Flux: onecold, onehotbatch, logitcrossentropy, DataLoader
using MLDatasets, MLUtils
using LinearAlgebra, Random, Statistics

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
Random.seed!(42); # for reproducibility
#

dataset = TUDataset("MUTAG")

#
dataset.graph_data.targets |> union

#
g1, y1 = dataset[1] # get the first graph and target

#
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union

#
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union

# This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**.

# By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**.
# It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes).
# However, for the sake of simplicity, we will not make use of edge labels.

# We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict):

graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                   ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                   edata = nothing)
          for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])
# We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:

train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs

train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

# Here, we opt for a `batchsize` of 32, leading to 5 (randomly shuffled) mini-batches, containing all $4 \cdot 32 + 22 = 150$ graphs.
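# As a quick sanity check (a small addition on top of the loader setup above), we can
# count the mini-batches the training loader produces:

length(train_loader) # 5 mini-batches covering the 150 training graphs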
# ## Mini-batching of graphs

# Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization.
# In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension.
# The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`.

# However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.
# Therefore, GraphNeuralNetworks.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension).

# This procedure has some crucial advantages over other batching procedures:

# 1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.

# 2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.

# GraphNeuralNetworks.jl can **batch multiple graphs into a single giant graph**:

vec_gs, _ = first(train_loader)

#
MLUtils.batch(vec_gs)

# Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:

# ```math
# \textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ]
# ```
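# For instance, we can inspect the batched graph built above (a small sketch; `g_batched`
# is just a name we introduce here for the batched version of `vec_gs`):

g_batched = MLUtils.batch(vec_gs) # one giant graph holding 32 isolated subgraphs
@show g_batched.num_graphs
g_batched.graph_indicator |> union # each node is assigned to one of the 32 graphs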
# ## Training a Graph Neural Network (GNN)

# Training a GNN for graph classification usually follows a simple recipe:

# 1. Embed each node by performing multiple rounds of message passing
# 2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
# 3. Train a final classifier on the graph embedding

# There exist multiple **readout layers** in the literature, but the most common one is to simply take the average of node embeddings:

# ```math
# \mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
# ```

# GraphNeuralNetworks.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`.
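# As a small illustration (a sketch reusing the `g_batched` graph from above, with the raw
# one-hot node features standing in for learned embeddings), the pooled output has one
# column per graph in the batch:

pool = GlobalPool(mean)
size(pool(g_batched, g_batched.ndata.x)) # (7, number of graphs in the batch)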
# The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

function create_model(nin, nh, nout)
    GNNChain(GCNConv(nin => nh, relu),
             GCNConv(nh => nh, relu),
             GCNConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;

# Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.

# Let's train our network for a few epochs to see how well it performs on the training as well as test set:

function eval_loss_accuracy(model, data_loader, device)
    loss = 0.0
    acc = 0.0
    ntot = 0
    for (g, y) in data_loader
        g, y = MLUtils.batch(g) |> device, y |> device
        n = length(y)
        ŷ = model(g, g.ndata.x)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end
function train!(model; epochs = 200, η = 1e-3, infotime = 10)
    ## device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    opt = Flux.setup(Adam(η), model)

    function report(epoch)
        train = eval_loss_accuracy(model, train_loader, device)
        test = eval_loss_accuracy(model, test_loader, device)
        @info (; epoch, train, test)
    end

    report(0)
    for epoch in 1:epochs
        for (g, y) in train_loader
            g, y = MLUtils.batch(g) |> device, y |> device
            grad = Flux.gradient(model) do model
                ŷ = model(g, g.ndata.x)
                logitcrossentropy(ŷ, y)
            end
            Flux.update!(opt, model, grad[1])
        end
        epoch % infotime == 0 && report(epoch)
    end
end

nin = 7
nh = 64
nout = 2
model = create_model(nin, nh, nout)
train!(model)
# As one can see, our model reaches around **75% test accuracy**.
# The fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and they usually disappear once GNNs are applied to larger datasets.

# ## (Optional) Exercise

# Can we do better than this?
# As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**.
# An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information:

# ```math
# \mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)}
# ```

# This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.

# As an exercise, you are invited to adapt the model above so that it makes use of `GraphConv` rather than `GCNConv`.
# This should bring you close to **82% test accuracy**.
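# One possible starting point is sketched below. This is not part of the original exercise
# material; it simply swaps `GCNConv` for `GraphConv` while keeping the same hidden sizes,
# so the hyperparameters may still need tuning:

## a sketch of one possible GraphConv-based model for the exercise
function create_graphconv_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;

## model = create_graphconv_model(nin, nh, nout); train!(model) # train it exactly as before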
# ## Conclusion

# In this tutorial, you have learned how to apply GNNs to the task of graph classification.
# You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.
