# # Node Classification with Graph Neural Networks

# In this tutorial, we will learn how to use Graph Neural Networks (GNNs) for node classification: given ground-truth labels for only a small subset of nodes, we want to infer the labels of all remaining nodes (transductive learning).


# ## Imports

# Let us start by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for this tutorial.

using Lux, GNNLux
using MLDatasets
using Plots, TSne
using Random, Statistics
using Zygote, Optimisers, OneHotArrays


ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
rng = Random.seed!(17) # for reproducibility

# ## Visualize

# We want to visualize the results by using t-distributed stochastic neighbor embedding (TSNE) to project the output embeddings onto a 2D plane. Note that `tsne` expects observations as rows, which is why we transpose the model outputs before calling `visualize_tsne` later on.

function visualize_tsne(out, targets)
    z = tsne(out, 2)
    scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false)
end;


# ## Dataset: Cora

# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents, classified into one of seven classes, with 5429 links. Each node represents a document, and an edge connects two documents when one cites the other.

# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.

# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for easy access to this dataset.

dataset = Cora()

# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself.

dataset.metadata

# The `graphs` field contains the graphs of the dataset. `Cora` contains only one graph.

dataset.graphs

# There is only one graph in the dataset. Its `node_data` contains `features`, indicating whether certain words are present or not, and `targets`, indicating the class of each document. We convert the single-graph dataset to a `GNNGraph`.

g = mldataset2gnngraph(dataset)

println("Number of nodes: $(g.num_nodes)")
53+
println("Number of edges: $(g.num_edges)")
54+
println("Average node degree: $(g.num_edges / g.num_nodes)")
55+
println("Number of training nodes: $(sum(g.ndata.train_mask))")
56+
println("Training node label rate: $(mean(g.ndata.train_mask))")
57+
println("Has isolated nodes: $(has_isolated_nodes(g))")
58+
println("Has self-loops: $(has_self_loops(g))")
59+
println("Is undirected: $(is_bidirected(g))")
60+
61+
62+
63+
# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network.
# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9.
# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class).
# This results in a training node label rate of only 5%.

# We can further see that this network is undirected, and that there are no isolated nodes (each document has at least one citation).

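# We can quickly verify the "20 training nodes per class" claim by counting the training labels per class (a small sanity check, not needed for the rest of the tutorial):

[count(==(c), g.ndata.targets[g.ndata.train_mask]) for c in 1:7]
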
x = g.ndata.features # input node features
y = onehotbatch(g.ndata.targets, 1:7) # one-hot encode the node labels (what we want to predict)
train_mask = g.ndata.train_mask;
num_features = size(x, 1);
hidden_channels = 16;
drop_rate = 0.5;
num_classes = dataset.metadata["num_classes"];

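# Each column of `y` is now the one-hot encoding of one node's class, so `y` should have size `(num_classes, num_nodes)`:

size(y)
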
# ## Multi-layer Perceptron Network (MLP)

# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account.

# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):

MLP = Chain(Dense(num_features => hidden_channels, relu),
            Dropout(drop_rate),
            Dense(hidden_channels => num_classes))

ps, st = Lux.setup(rng, MLP)

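# As a quick sanity check before training, we can run one forward pass and confirm that the output has one row per class and one column per node (following the Lux calling convention `model(x, ps, st) -> (ŷ, st)`):

ŷ, _ = MLP(x, ps, st)
size(ŷ) # expected: (num_classes, num_nodes)
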
# ### Training a Multilayer Perceptron

# Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout).
# Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes.

# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/).
# We again make use of the **cross entropy loss** and **Adam optimizer**.
# This time, we also define an **`accuracy` function** to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).

function custom_loss(model, ps, st, x)
    # Lux's training API expects the loss function to return (loss, state, stats)
    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
    ŷ, st = model(x, ps, st)
    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), st, 0
end

function train_model!(MLP, ps, st, x, epochs)
    train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3))
    for iter in 1:epochs
        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state)
        if iter % 100 == 0
            println("Epoch: $(iter) Loss: $(loss)")
        end
    end
    return MLP, train_state.parameters, train_state.states # return the trained parameters and states
end

function accuracy(model, x, ps, st, y, mask)
    st = Lux.testmode(st)
    ŷ, st = model(x, ps, st)
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end

MLP, ps, st = train_model!(MLP, ps, st, x, 2000)

# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels.
# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes:

accuracy(MLP, x, ps, st, y, .!train_mask)

# As one can see, our MLP performs rather badly, with only about 50% test accuracy.
# But why does the MLP not perform better?
# The main reason is that this model suffers from heavy overfitting, due to only having access to a **small number of training nodes**, and therefore generalizes poorly to unseen node representations.

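# We can see the overfitting directly by evaluating the same model on the training nodes, where accuracy is typically far higher than on the test set:

accuracy(MLP, x, ps, st, y, train_mask)
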
# It also fails to incorporate an important bias into the model: **cited papers are very likely related to the category of a document**.
# That is exactly where Graph Neural Networks come into play and can help boost the performance of our model.


# ## Training a Graph Convolutional Network (GCN)

# Following up on the first part of this tutorial, we replace the `Dense` linear layers with the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module.
# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as

# ```math
# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}
# ```

# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. In Kipf et al.'s symmetric normalization, $c_{w,v} = \sqrt{\hat{d}_w \hat{d}_v}$, where $\hat{d}_v$ is the degree of node $v$ counting the added self-loop.
# In contrast, a single `Linear` layer is defined as

# ```math
# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}
# ```

# which does not make use of neighboring node information.

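# As a tiny worked example of that normalization (an aside, not used in the model below): a message along an edge between nodes of self-loop-augmented degrees 2 and 3 is scaled by

1 / sqrt(2 * 3) # 1 / c_{w,v} ≈ 0.41
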
Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} # the type parameter names the sub-layer fields
    nf::Int
    nc::Int
    hd::Int
    conv1
    conv2
    drop
    use_bias::Bool
    init_weight
    init_bias
end

function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor
    conv1 = GCNConv(num_features => hidden_channels)
    conv2 = GCNConv(hidden_channels => num_classes)
    drop = Dropout(drop_rate)
    return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias)
end

function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass
    x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1)
    x = relu.(x)
    x, stdrop = gcn.drop(x, ps.drop, st.drop)
    x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2)
    return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2)
end

# Now let's visualize the node embeddings of our **untrained** GCN network.

gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
ps, st = Lux.setup(rng, gcn)
h_untrained, st = gcn(g, x, ps, st)
h_untrained = h_untrained |> transpose
visualize_tsne(h_untrained, g.ndata.targets)

# We certainly can do better by training our model.
# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model.

function custom_loss(gcn, ps, st, tuple)
    g, x, y = tuple
    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
    ŷ, st = gcn(g, x, ps, st)
    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), st, 0
end

function train_model!(gcn, ps, st, g, x, y)
    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
    for iter in 1:2000
        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, x, y), train_state)
        if iter % 100 == 0
            println("Epoch: $(iter) Loss: $(loss)")
        end
    end
    return gcn, train_state.parameters, train_state.states # return the trained parameters and states
end

gcn, ps, st = train_model!(gcn, ps, st, g, x, y);

# Now let's evaluate the accuracy of our trained GCN. We add a method of `accuracy` that also takes the graph as input:

function accuracy(model, g, x, ps, st, y, mask)
    st = Lux.testmode(st)
    ŷ, st = model(g, x, ps, st)
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end

train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask)
test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask)

println("Train accuracy: $(train_accuracy)")
println("Test accuracy: $(test_accuracy)")

# **There it is!**
# By simply swapping the linear layers for GNN layers, we reach around **76% test accuracy**!
# This is in stark contrast to the 50% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.

# We can also verify that once again by looking at the output embeddings of our trained model, which now produce a far better clustering of nodes of the same category.

st = Lux.testmode(st) # inference mode

out_trained, st = gcn(g, x, ps, st)
out_trained = out_trained |> transpose
visualize_tsne(out_trained, g.ndata.targets)

# ## (Optional) Exercises

# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**.

# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?

# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a dropout ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimension of `8` per head. A possible starting skeleton is sketched below.

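# A minimal skeleton for exercise 3, assuming `GATConv` here follows the same `heads` keyword convention as its GraphNeuralNetworks.jl counterpart (the training loop and the dropout inside each `GATConv` call are left for you to fill in):

GAT = GNNChain(GATConv(num_features => 8, relu; heads = 8), # 8 heads, concatenated: output dim 8 * 8
               Dropout(0.6),
               GATConv(8 * 8 => num_classes; heads = 1))
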
# ## Conclusion

# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.