|
| 1 | +# # Node Classification with Graph Neural Networks |
| 2 | + |
| 3 | +# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). |
| 4 | + |
| 5 | + |
| 6 | +# ## Import |
| 7 | +# Let us start off by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for our tutorial. |
| 8 | + |
| 9 | +using Lux, GNNLux |
| 10 | +using MLDatasets |
| 11 | +using Plots, TSne |
| 12 | +using Random, Statistics |
| 13 | +using Zygote, Optimisers, OneHotArrays |
| 14 | + |
| 15 | + |
| 16 | +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation |
| 17 | +rng = Random.seed!(17) # for reproducibility |
| 18 | + |
| 19 | +# ## Visualize |
| 20 | +# We want to visualize the outputs of the results using t-distributed stochastic neighbor embedding (tsne) to embed our output embeddings onto a 2D plane. |
| 21 | + |
| 22 | +function visualize_tsne(out, targets) |
| 23 | + z = tsne(out, 2) |
| 24 | + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) |
| 25 | +end; |
| 26 | + |
| 27 | + |
| 28 | +# ## Dataset: Cora |
| 29 | + |
| 30 | +# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents classified into one of seven classes and 5429 links. Each node represents articles/documents and the edges between them when they cite each other. |
| 31 | + |
| 32 | +# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. |
| 33 | + |
| 34 | +# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. |
| 35 | + |
| 36 | +dataset = Cora() |
| 37 | + |
| 38 | +# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. |
| 39 | + |
| 40 | +dataset.metadata |
| 41 | + |
| 42 | +# The `graphs` variable GraphDataset contains the graph. The `Cora` dataset contains only 1 graph. |
| 43 | + |
| 44 | +dataset.graphs |
| 45 | + |
| 46 | + |
| 47 | +# There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. |
| 48 | + |
| 49 | +g = mldataset2gnngraph(dataset) |
| 50 | + |
| 51 | + |
| 52 | +println("Number of nodes: $(g.num_nodes)") |
| 53 | +println("Number of edges: $(g.num_edges)") |
| 54 | +println("Average node degree: $(g.num_edges / g.num_nodes)") |
| 55 | +println("Number of training nodes: $(sum(g.ndata.train_mask))") |
| 56 | +println("Training node label rate: $(mean(g.ndata.train_mask))") |
| 57 | +println("Has isolated nodes: $(has_isolated_nodes(g))") |
| 58 | +println("Has self-loops: $(has_self_loops(g))") |
| 59 | +println("Is undirected: $(is_bidirected(g))") |
| 60 | + |
| 61 | + |
| 62 | + |
| 63 | +# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. |
| 64 | +# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. |
| 65 | +# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). |
| 66 | +# This results in a training node label rate of only 5%. |
| 67 | + |
| 68 | +# We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). |
| 69 | + |
| 70 | +x = g.ndata.features # we onehot encode both the node labels (what we want to predict): |
| 71 | +y = onehotbatch(g.ndata.targets, 1:7) |
| 72 | +train_mask = g.ndata.train_mask; |
| 73 | +num_features = size(x)[1]; |
| 74 | +hidden_channels = 16; |
| 75 | +drop_rate = 0.5; |
| 76 | +num_classes = dataset.metadata["num_classes"]; |
| 77 | + |
| 78 | + |
| 79 | +# ## Multi-layer Perception Network (MLP) |
| 80 | + |
| 81 | +# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. |
| 82 | + |
| 83 | +# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): |
| 84 | + |
| 85 | +MLP = Chain(Dense(num_features => hidden_channels, relu), |
| 86 | + Dropout(drop_rate), |
| 87 | + Dense(hidden_channels => num_classes)) |
| 88 | + |
| 89 | +ps, st = Lux.setup(rng, MLP) |
| 90 | + |
| 91 | +# ### Training a Multilayer Perceptron |
| 92 | + |
| 93 | +# Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout). |
| 94 | +# Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. |
| 95 | + |
| 96 | +# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/). |
| 97 | +# We again make use of the **cross entropy loss** and **Adam optimizer**. |
| 98 | +# This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). |
| 99 | + |
| 100 | + |
| 101 | +function custom_loss(model, ps, st, x) |
| 102 | + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) |
| 103 | + ŷ, st = model(x, ps, st) |
| 104 | + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 |
| 105 | +end |
| 106 | + |
| 107 | +function train_model!(MLP, ps, st, x, epochs) |
| 108 | + train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3)) |
| 109 | + for iter in 1:epochs |
| 110 | + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state) |
| 111 | + |
| 112 | + if iter % 100 == 0 |
| 113 | + println("Epoch: $(iter) Loss: $(loss)") |
| 114 | + end |
| 115 | + end |
| 116 | +end |
| 117 | + |
| 118 | +function accuracy(model, x, ps, st, y, mask) |
| 119 | + st = Lux.testmode(st) |
| 120 | + ŷ, st = model(x, ps, st) |
| 121 | + mean(onecold(ŷ)[mask] .== onecold(y)[mask]) |
| 122 | +end |
| 123 | + |
| 124 | +train_model!(MLP, ps, st, x, 2000) |
| 125 | + |
| 126 | + |
| 127 | + |
| 128 | +# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. |
| 129 | +# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: |
| 130 | + |
| 131 | +accuracy(MLP, x, ps, st, y, .!train_mask) |
| 132 | + |
| 133 | +# As one can see, our MLP performs rather bad with only about ~50% test accuracy. |
| 134 | +# But why does the MLP do not perform better? |
| 135 | +# The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. |
| 136 | + |
| 137 | +# It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. |
| 138 | +# That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. |
| 139 | + |
| 140 | + |
| 141 | +# ## Training a Graph Convolutional Neural Network (GNN) |
| 142 | + |
| 143 | +# Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module. |
| 144 | +# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as |
| 145 | + |
| 146 | +# ```math |
| 147 | +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} |
| 148 | +# ``` |
| 149 | + |
| 150 | +# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. |
| 151 | +# In contrast, a single `Linear` layer is defined as |
| 152 | + |
| 153 | +# ```math |
| 154 | +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} |
| 155 | +# ``` |
| 156 | + |
| 157 | +# which does not make use of neighboring node information. |
| 158 | + |
| 159 | +Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} |
| 160 | + nf::Int |
| 161 | + nc::Int |
| 162 | + hd::Int |
| 163 | + conv1 |
| 164 | + conv2 |
| 165 | + drop |
| 166 | + use_bias::Bool |
| 167 | + init_weight |
| 168 | + init_bias |
| 169 | +end |
| 170 | + |
| 171 | +function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor |
| 172 | + conv1 = GCNConv(num_features => hidden_channels) |
| 173 | + conv2 = GCNConv(hidden_channels => num_classes) |
| 174 | + drop = Dropout(drop_rate) |
| 175 | + return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias) |
| 176 | +end |
| 177 | + |
| 178 | +function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass |
| 179 | + x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) |
| 180 | + x = relu.(x) |
| 181 | + x, stdrop = gcn.drop(x, ps.drop, st.drop) |
| 182 | + x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) |
| 183 | + return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2) |
| 184 | +end |
| 185 | + |
| 186 | + |
| 187 | +# function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters |
| 188 | +# weight_c1 = l.init_weight(rng, l.hd, l.nf) |
| 189 | +# weight_c2 = l.init_weight(rng, l.nc, l.hd) |
| 190 | +# if l.use_bias |
| 191 | +# bias_c1 = l.init_bias(rng, l.hd) |
| 192 | +# bias_c2 = l.init_bias(rng, l.nc) |
| 193 | +# return (; conv1 = ( weight = weight_c1, bias = bias_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2, bias = bias_c2)) |
| 194 | +# end |
| 195 | +# return (; conv1 = ( weight = weight_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2)) |
| 196 | +# end |
| 197 | + |
| 198 | + |
| 199 | +# Now let's visualize the node embeddings of our **untrained** GCN network. |
| 200 | + |
| 201 | + |
| 202 | + |
| 203 | +gcn = GCN(num_features, num_classes, hidden_channels, drop_rate) |
| 204 | +ps, st = Lux.setup(rng, gcn) |
| 205 | +h_untrained, st = gcn(g, x, ps, st) |
| 206 | +h_untrained = h_untrained |> transpose |
| 207 | +visualize_tsne(h_untrained, g.ndata.targets) |
| 208 | + |
| 209 | + |
| 210 | +# We certainly can do better by training our model. |
| 211 | +# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. |
| 212 | + |
| 213 | + |
| 214 | + |
| 215 | +function custom_loss(gcn, ps, st, tuple) |
| 216 | + g, x, y = tuple |
| 217 | + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) |
| 218 | + ŷ, st = gcn(g, x, ps, st) |
| 219 | + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 |
| 220 | +end |
| 221 | + |
| 222 | +function train_model!(gcn, ps, st, g, x, y) |
| 223 | + train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2)) |
| 224 | + for iter in 1:2000 |
| 225 | + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state) |
| 226 | + |
| 227 | + if iter % 100 == 0 |
| 228 | + println("Epoch: $(iter) Loss: $(loss)") |
| 229 | + end |
| 230 | + end |
| 231 | + |
| 232 | + return gcn, ps, st |
| 233 | +end |
| 234 | + |
| 235 | +gcn, ps, st = train_model!(gcn, ps, st, g, x, y); |
| 236 | + |
| 237 | + |
| 238 | + |
| 239 | +# Now let's evaluate the loss of our trained GCN. |
| 240 | + |
| 241 | +function accuracy(model, g, x, ps, st, y, mask) |
| 242 | + st = Lux.testmode(st) |
| 243 | + ŷ, st = model(g, x, ps, st) |
| 244 | + mean(onecold(ŷ)[mask] .== onecold(y)[mask]) |
| 245 | +end |
| 246 | + |
| 247 | +train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask) |
| 248 | +test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask) |
| 249 | + |
| 250 | +println("Train accuracy: $(train_accuracy)") |
| 251 | +println("Test accuracy: $(test_accuracy)") |
| 252 | + |
| 253 | +# **There it is!** |
| 254 | +# By simply swapping the linear layers with GNN layers, we can reach **76% of test accuracy**! |
| 255 | +# This is in stark contrast to the 50% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. |
| 256 | + |
| 257 | +# We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category. |
| 258 | + |
| 259 | + |
| 260 | + |
| 261 | +st = Lux.testmode(st) # inference mode |
| 262 | + |
| 263 | +out_trained, st = gcn(g, x, ps, st) |
| 264 | +out_trained = out_trained|> transpose |
| 265 | +visualize_tsne(out_trained, g.ndata.targets) |
| 266 | + |
| 267 | +# ## (Optional) Exercises |
| 268 | + |
| 269 | +# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. |
| 270 | + |
| 271 | +# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? |
| 272 | + |
| 273 | +# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. |
| 274 | + |
| 275 | + |
| 276 | +# ## Conclusion |
| 277 | +# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. |
0 commit comments