Commit 5eef31d: Add node_classification literate src
# # Node Classification with Graph Neural Networks

# In this tutorial, we will learn how to use Graph Neural Networks (GNNs) for node classification: given the ground-truth labels of only a small subset of nodes, we want to infer the labels of all the remaining nodes (transductive learning).

# ## Import
# Let us start off by importing some libraries. We will be using `Flux.jl` and `GraphNeuralNetworks.jl` for our tutorial.

using Flux, GraphNeuralNetworks
using Flux: onecold, onehotbatch, logitcrossentropy
using MLDatasets
using Plots, TSne
using Statistics, Random

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
Random.seed!(17); # for reproducibility
# ## Visualize
# We want to visualize our results using t-distributed stochastic neighbor embedding (t-SNE) to project our output onto a 2D plane.

function visualize_tsne(out, targets)
    z = tsne(out, 2)
    scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false)
end;
# ## Dataset: Cora

# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2,708 documents categorized into seven classes, with 5,429 citation links. Each node represents a document, and an edge between two nodes indicates a citation relationship, where one document cites the other.

# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1,433 unique words.

# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for easy access to this dataset.

dataset = Cora()

# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself.

dataset.metadata

# The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph.

dataset.graphs

# There is only one graph in the dataset. Its `node_data` contains `features`, indicating whether certain words are present or not, and `targets`, indicating the class of each document. We convert the single-graph dataset to a `GNNGraph`.

g = mldataset2gnngraph(dataset)
println("Number of nodes: $(g.num_nodes)")
49+
println("Number of edges: $(g.num_edges)")
50+
println("Average node degree: $(g.num_edges / g.num_nodes)")
51+
println("Number of training nodes: $(sum(g.ndata.train_mask))")
52+
println("Training node label rate: $(mean(g.ndata.train_mask))")
53+
println("Has isolated nodes: $(has_isolated_nodes(g))")
54+
println("Has self-loops: $(has_self_loops(g))")
55+
println("Is undirected: $(is_bidirected(g))")
56+
57+
58+
# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network.
59+
# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9.
60+
# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class).
61+
# This results in a training node label rate of only 5%.
62+
63+
# We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation).
64+
65+
x = g.ndata.features # we onehot encode both the node labels (what we want to predict):
66+
y = onehotbatch(g.ndata.targets, 1:7)
67+
train_mask = g.ndata.train_mask
68+
num_features = size(x)[1]
69+
hidden_channels = 16
70+
num_classes = dataset.metadata["num_classes"];
71+
72+
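# As a quick sanity check (a small sketch, not part of the original tutorial code), the shapes should match the description above: a 1433-dimensional feature vector for each of the 2,708 nodes, and a 7-class one-hot label matrix.

println("Feature matrix size: $(size(x))") # expect (1433, 2708)
println("Label matrix size: $(size(y))") # expect (7, 2708)
println("Unique feature values: $(sort(unique(x)))") # expect the 0/1 values described above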
# ## Multi-layer Perceptron (MLP)

# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account.

# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):

struct MLP
    layers::NamedTuple
end

Flux.@layer :expand MLP

function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5)
    layers = (hidden = Dense(num_features => hidden_channels),
              drop = Dropout(drop_rate),
              classifier = Dense(hidden_channels => num_classes))
    return MLP(layers)
end;

function (model::MLP)(x::AbstractMatrix)
    l = model.layers
    x = l.hidden(x)
    x = relu.(x) # broadcast the scalar ReLU activation over the feature matrix
    x = l.drop(x)
    x = l.classifier(x)
    return x
end;
# ### Training a Multilayer Perceptron

# Our MLP is defined by two linear layers, enhanced by [ReLU](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib.relu) non-linearity and [Dropout](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Dropout).
# Here, the first linear layer reduces the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that maps each low-dimensional node embedding to one of the 7 classes.

# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/tutorials/gnn_intro/).
# We again make use of the **cross entropy loss** and **Adam optimizer**.
# This time, we also define an **`accuracy` function** to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).

function train(model::MLP, data::AbstractMatrix, epochs::Int, opt)
    Flux.trainmode!(model)

    for epoch in 1:epochs
        loss, grad = Flux.withgradient(model) do model
            ŷ = model(data)
            logitcrossentropy(ŷ[:, train_mask], y[:, train_mask])
        end

        Flux.update!(opt, model, grad[1])
        if epoch % 200 == 0
            @show epoch, loss
        end
    end
end;

function accuracy(model::MLP, x::AbstractMatrix, y::Flux.OneHotArray, mask::BitVector)
    Flux.testmode!(model)
    mean(onecold(model(x))[mask] .== onecold(y)[mask])
end;
mlp = MLP(num_features, num_classes, hidden_channels)
opt_mlp = Flux.setup(Adam(1e-3), mlp)
epochs = 2000
train(mlp, g.ndata.features, epochs, opt_mlp)

# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels.
# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes:

accuracy(mlp, g.ndata.features, y, .!train_mask)
# As one can see, our MLP performs rather badly, reaching only about 50% test accuracy.
# But why doesn't the MLP perform better?
# The main reason is that the model suffers from heavy overfitting: it only has access to a **small number of training nodes**, and therefore generalizes poorly to unseen node representations.

# It also fails to incorporate an important bias into the model: **cited papers are very likely related to the category of a document**.
# That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
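# To see the overfitting in the numbers, we can compare the MLP's accuracy on the training nodes with its accuracy on the held-out nodes (a quick sketch reusing the `accuracy` function defined above; near-perfect train accuracy alongside low test accuracy is the signature of overfitting):

println("MLP train accuracy: $(accuracy(mlp, g.ndata.features, y, train_mask))")
println("MLP test accuracy: $(accuracy(mlp, g.ndata.features, y, .!train_mask))")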
# ## Training a Graph Convolutional Neural Network (GCN)

# Following up on the first part of this tutorial, we replace the `Dense` linear layers with the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/api/conv/#GraphNeuralNetworks.GCNConv) module.
# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as

# ```math
# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}
# ```

# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge.
# In contrast, a single `Dense` layer is defined as

# ```math
# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}
# ```

# which does not make use of neighboring node information.
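# To make the aggregation concrete, the sketch below (an illustration, not part of the tutorial pipeline) performs one propagation step by hand on a tiny 3-node path graph. It assumes the symmetric normalization $c_{w,v} = \sqrt{\deg(w)}\sqrt{\deg(v)}$ from Kipf et al., with degrees counted after adding self-loops; in matrix form the aggregation then reads $\mathbf{X} \hat{\mathbf{D}}^{-1/2} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-1/2}$ with $\hat{\mathbf{A}} = \mathbf{A} + \mathbf{I}$.

using LinearAlgebra: Diagonal, I

A = [0 1 0; 1 0 1; 0 1 0] # adjacency matrix of the path graph 1 - 2 - 3
Â = A + I # add self-loops
deg = vec(sum(Â, dims = 2)) # node degrees, self-loops included
D_inv_sqrt = Diagonal(1 ./ sqrt.(deg)) # D̂^(-1/2)
X_toy = rand(Float32, 4, 3) # 4 input features for each of the 3 nodes
X_agg = X_toy * (D_inv_sqrt * Â * D_inv_sqrt) # aggregated features, before applying W and the nonlinearity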
struct GCN
    layers::NamedTuple
end

Flux.@layer GCN # provides parameter collection, gpu movement and more

function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5)
    layers = (conv1 = GCNConv(num_features => hidden_channels),
              drop = Dropout(drop_rate),
              conv2 = GCNConv(hidden_channels => num_classes))
    return GCN(layers)
end;

function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix)
    l = gcn.layers
    x = l.conv1(g, x)
    x = relu.(x)
    x = l.drop(x)
    x = l.conv2(g, x)
    return x
end;

# Now let's visualize the node embeddings of our **untrained** GCN network.

gcn = GCN(num_features, num_classes, hidden_channels)
h_untrained = gcn(g, x) |> transpose
visualize_tsne(h_untrained, g.ndata.targets)
# We certainly can do better by training our model.
# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model.

function train(model::GCN, g::GNNGraph, x::AbstractMatrix, epochs::Int, opt)
    Flux.trainmode!(model)

    for epoch in 1:epochs
        loss, grad = Flux.withgradient(model) do model
            ŷ = model(g, x)
            logitcrossentropy(ŷ[:, train_mask], y[:, train_mask])
        end

        Flux.update!(opt, model, grad[1])
        if epoch % 200 == 0
            @show epoch, loss
        end
    end
end;

# We also extend the `accuracy` function to our GCN, which takes the graph as an additional argument:
function accuracy(model::GCN, g::GNNGraph, x::AbstractMatrix, y::Flux.OneHotArray,
                  mask::BitVector)
    Flux.testmode!(model)
    mean(onecold(model(g, x))[mask] .== onecold(y)[mask])
end;

# Let's now train our GCN model, this time with a higher learning rate:
opt_gcn = Flux.setup(Adam(1e-2), gcn)
train(gcn, g, x, epochs, opt_gcn)

# Now let's evaluate the accuracy of our trained GCN.

train_accuracy = accuracy(gcn, g, g.ndata.features, y, train_mask)
test_accuracy = accuracy(gcn, g, g.ndata.features, y, .!train_mask)

println("Train accuracy: $(train_accuracy)")
println("Test accuracy: $(test_accuracy)")
# **There it is!**
# By simply swapping the linear layers with GNN layers, we can reach **76% test accuracy**!
# This is in stark contrast to the roughly 50% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.

# We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.

Flux.testmode!(gcn) # inference mode

out_trained = gcn(g, x) |> transpose
visualize_tsne(out_trained, g.ndata.targets)
# ## (Optional) Exercises

# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance up to **82% accuracy**.

# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?

# 3. You can try different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimension of `8` per head. A starting-point sketch follows this list.
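# For exercise 3, here is a minimal starting-point sketch (an outline under assumptions, not a reference solution). It uses 8 attention heads in the first `GATConv` layer, 1 head in the second, `hidden_channels = 8` per head, and `Dropout(0.6)` before each convolution. Since the first layer concatenates its heads by default, it outputs `8 * 8 = 64` features. Whether attention dropout can also be applied *inside* `GATConv` depends on your package version, so this sketch only applies dropout outside the layers.

struct GAT
    layers::NamedTuple
end

Flux.@layer GAT

function GAT(num_features, num_classes; heads = 8, hidden_channels = 8, drop_rate = 0.6)
    layers = (drop1 = Dropout(drop_rate),
              conv1 = GATConv(num_features => hidden_channels; heads),
              drop2 = Dropout(drop_rate),
              conv2 = GATConv(hidden_channels * heads => num_classes; heads = 1))
    return GAT(layers)
end;

function (gat::GAT)(g::GNNGraph, x::AbstractMatrix)
    l = gat.layers
    x = l.drop1(x)
    x = l.conv1(g, x)
    x = elu.(x) # ELU activation, as in the original GAT paper
    x = l.drop2(x)
    x = l.conv2(g, x)
    return x
end;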
# ## Conclusion
# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.
