|
| 1 | +# # Node Classification with Graph Neural Networks |
| 2 | + |
| 3 | +# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). |
| 4 | + |
| 5 | +# ## Import |
| 6 | +# Let us start off by importing some libraries. We will be using `Flux.jl` and `GraphNeuralNetworks.jl` for our tutorial. |
| 7 | + |
| 8 | +using Flux, GraphNeuralNetworks |
| 9 | +using Flux: onecold, onehotbatch, logitcrossentropy |
| 10 | +using MLDatasets |
| 11 | +using Plots, TSne |
| 12 | +using Statistics, Random |
| 13 | + |
| 14 | +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation |
| 15 | +Random.seed!(17); # for reproducibility |
| 16 | + |
| 17 | +# ## Visualize |
| 18 | +# We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane. |
| 19 | + |
| 20 | +function visualize_tsne(out, targets) |
| 21 | + z = tsne(out, 2) |
| 22 | + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) |
| 23 | +end; |
| 24 | + |
| 25 | +# ## Dataset: Cora |
| 26 | + |
| 27 | +# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents categorized into seven classes with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other. |
| 28 | + |
| 29 | +# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. |
| 30 | + |
| 31 | +# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. |
| 32 | + |
| 33 | +dataset = Cora() |
| 34 | + |
| 35 | +# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. |
| 36 | + |
| 37 | +dataset.metadata |
| 38 | + |
| 39 | +# The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph. |
| 40 | + |
| 41 | + |
| 42 | +dataset.graphs |
| 43 | + |
| 44 | +# There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. |
| 45 | + |
| 46 | +g = mldataset2gnngraph(dataset) |
| 47 | + |
| 48 | +println("Number of nodes: $(g.num_nodes)") |
| 49 | +println("Number of edges: $(g.num_edges)") |
| 50 | +println("Average node degree: $(g.num_edges / g.num_nodes)") |
| 51 | +println("Number of training nodes: $(sum(g.ndata.train_mask))") |
| 52 | +println("Training node label rate: $(mean(g.ndata.train_mask))") |
| 53 | +println("Has isolated nodes: $(has_isolated_nodes(g))") |
| 54 | +println("Has self-loops: $(has_self_loops(g))") |
| 55 | +println("Is undirected: $(is_bidirected(g))") |
| 56 | + |
| 57 | + |
| 58 | +# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. |
| 59 | +# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. |
| 60 | +# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). |
| 61 | +# This results in a training node label rate of only 5%. |
| 62 | + |
| 63 | +# We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). |
| 64 | + |
| 65 | +x = g.ndata.features # we onehot encode both the node labels (what we want to predict): |
| 66 | +y = onehotbatch(g.ndata.targets, 1:7) |
| 67 | +train_mask = g.ndata.train_mask |
| 68 | +num_features = size(x)[1] |
| 69 | +hidden_channels = 16 |
| 70 | +num_classes = dataset.metadata["num_classes"]; |
| 71 | + |
| 72 | +# ## Multi-layer Perception Network (MLP) |
| 73 | + |
| 74 | +# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. |
| 75 | + |
| 76 | +# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): |
| 77 | + |
| 78 | +struct MLP |
| 79 | + layers::NamedTuple |
| 80 | +end |
| 81 | + |
| 82 | +Flux.@layer :expand MLP |
| 83 | + |
| 84 | +function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5) |
| 85 | + layers = (hidden = Dense(num_features => hidden_channels), |
| 86 | + drop = Dropout(drop_rate), |
| 87 | + classifier = Dense(hidden_channels => num_classes)) |
| 88 | + return MLP(layers) |
| 89 | +end; |
| 90 | + |
| 91 | +function (model::MLP)(x::AbstractMatrix) |
| 92 | + l = model.layers |
| 93 | + x = l.hidden(x) |
| 94 | + x = relu(x) |
| 95 | + x = l.drop(x) |
| 96 | + x = l.classifier(x) |
| 97 | + return x |
| 98 | +end; |
| 99 | + |
| 100 | +# ### Training a Multilayer Perceptron |
| 101 | + |
| 102 | +# Our MLP is defined by two linear layers and enhanced by [ReLU](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib.relu) non-linearity and [Dropout](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Dropout). |
| 103 | +# Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. |
| 104 | + |
| 105 | +# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/tutorials/gnn_intro/). |
| 106 | +# We again make use of the **cross entropy loss** and **Adam optimizer**. |
| 107 | +# This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). |
| 108 | + |
| 109 | +function train(model::MLP, data::AbstractMatrix, epochs::Int, opt) |
| 110 | + Flux.trainmode!(model) |
| 111 | + |
| 112 | + for epoch in 1:epochs |
| 113 | + loss, grad = Flux.withgradient(model) do model |
| 114 | + ŷ = model(data) |
| 115 | + logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) |
| 116 | + end |
| 117 | + |
| 118 | + Flux.update!(opt, model, grad[1]) |
| 119 | + if epoch % 200 == 0 |
| 120 | + @show epoch, loss |
| 121 | + end |
| 122 | + end |
| 123 | +end; |
| 124 | + |
| 125 | +function accuracy(model::MLP, x::AbstractMatrix, y::Flux.OneHotArray, mask::BitVector) |
| 126 | + Flux.testmode!(model) |
| 127 | + mean(onecold(model(x))[mask] .== onecold(y)[mask]) |
| 128 | +end; |
| 129 | + |
| 130 | +mlp = MLP(num_features, num_classes, hidden_channels) |
| 131 | +opt_mlp = Flux.setup(Adam(1e-3), mlp) |
| 132 | +epochs = 2000 |
| 133 | +train(mlp, g.ndata.features, epochs, opt_mlp) |
| 134 | + |
| 135 | +# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. |
| 136 | +# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: |
| 137 | + |
| 138 | +accuracy(mlp, g.ndata.features, y, .!train_mask) |
| 139 | + |
| 140 | + |
| 141 | +# As one can see, our MLP performs rather bad with only about ~50% test accuracy. |
| 142 | +# But why does the MLP do not perform better? |
| 143 | +# The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. |
| 144 | + |
| 145 | +# It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. |
| 146 | +# That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. |
| 147 | + |
| 148 | + |
| 149 | + |
| 150 | +# ## Training a Graph Convolutional Neural Network (GNN) |
| 151 | + |
| 152 | +# Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/api/conv/#GraphNeuralNetworks.GCNConv) module. |
| 153 | +# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as |
| 154 | + |
| 155 | +# ```math |
| 156 | +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} |
| 157 | +# ``` |
| 158 | + |
| 159 | +# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. |
| 160 | +# In contrast, a single `Linear` layer is defined as |
| 161 | + |
| 162 | +# ```math |
| 163 | +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} |
| 164 | +# ``` |
| 165 | + |
| 166 | +# which does not make use of neighboring node information. |
| 167 | + |
| 168 | +struct GCN |
| 169 | + layers::NamedTuple |
| 170 | +end |
| 171 | + |
| 172 | +Flux.@layer GCN # provides parameter collection, gpu movement and more |
| 173 | + |
| 174 | +function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5) |
| 175 | + layers = (conv1 = GCNConv(num_features => hidden_channels), |
| 176 | + drop = Dropout(drop_rate), |
| 177 | + conv2 = GCNConv(hidden_channels => num_classes)) |
| 178 | + return GCN(layers) |
| 179 | +end; |
| 180 | + |
| 181 | +function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix) |
| 182 | + l = gcn.layers |
| 183 | + x = l.conv1(g, x) |
| 184 | + x = relu.(x) |
| 185 | + x = l.drop(x) |
| 186 | + x = l.conv2(g, x) |
| 187 | + return x |
| 188 | +end; |
| 189 | + |
| 190 | + |
| 191 | +# Now let's visualize the node embeddings of our **untrained** GCN network. |
| 192 | + |
| 193 | +gcn = GCN(num_features, num_classes, hidden_channels) |
| 194 | +h_untrained = gcn(g, x) |> transpose |
| 195 | +visualize_tsne(h_untrained, g.ndata.targets) |
| 196 | + |
| 197 | + |
| 198 | +# We certainly can do better by training our model. |
| 199 | +# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. |
| 200 | + |
| 201 | +function train(model::GCN, g::GNNGraph, x::AbstractMatrix, epochs::Int, opt) |
| 202 | + Flux.trainmode!(model) |
| 203 | + |
| 204 | + for epoch in 1:epochs |
| 205 | + loss, grad = Flux.withgradient(model) do model |
| 206 | + ŷ = model(g, x) |
| 207 | + logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]) |
| 208 | + end |
| 209 | + |
| 210 | + Flux.update!(opt, model, grad[1]) |
| 211 | + if epoch % 200 == 0 |
| 212 | + @show epoch, loss |
| 213 | + end |
| 214 | + end |
| 215 | +end; |
| 216 | + |
| 217 | +# |
| 218 | + |
| 219 | +mlp = MLP(num_features, num_classes, hidden_channels) |
| 220 | +opt_mlp = Flux.setup(Adam(1e-3), mlp) |
| 221 | +epochs = 2000 |
| 222 | +train(mlp, g.ndata.features, epochs, opt_mlp) |
| 223 | + |
| 224 | +# |
| 225 | +function accuracy(model::GCN, g::GNNGraph, x::AbstractMatrix, y::Flux.OneHotArray, |
| 226 | + mask::BitVector) |
| 227 | + Flux.testmode!(model) |
| 228 | + mean(onecold(model(g, x))[mask] .== onecold(y)[mask]) |
| 229 | +end |
| 230 | + |
| 231 | +# |
| 232 | + |
| 233 | +accuracy(mlp, g.ndata.features, y, .!train_mask) |
| 234 | + |
| 235 | +# |
| 236 | + |
| 237 | +opt_gcn = Flux.setup(Adam(1e-2), gcn) |
| 238 | +train(gcn, g, x, epochs, opt_gcn) |
| 239 | + |
| 240 | + |
| 241 | +# Now let's evaluate the loss of our trained GCN. |
| 242 | + |
| 243 | +train_accuracy = accuracy(gcn, g, g.ndata.features, y, train_mask) |
| 244 | +test_accuracy = accuracy(gcn, g, g.ndata.features, y, .!train_mask) |
| 245 | + |
| 246 | +println("Train accuracy: $(train_accuracy)") |
| 247 | +println("Test accuracy: $(test_accuracy)") |
| 248 | + |
| 249 | + |
| 250 | +# **There it is!** |
| 251 | +# By simply swapping the linear layers with GNN layers, we can reach **76% of test accuracy**! |
| 252 | +# This is in stark contrast to the 59% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. |
| 253 | + |
| 254 | +# We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category. |
| 255 | + |
| 256 | + |
| 257 | +Flux.testmode!(gcn) # inference mode |
| 258 | + |
| 259 | +out_trained = gcn(g, x) |> transpose |
| 260 | +visualize_tsne(out_trained, g.ndata.targets) |
| 261 | + |
| 262 | + |
| 263 | + |
| 264 | +# ## (Optional) Exercises |
| 265 | + |
| 266 | +# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. |
| 267 | + |
| 268 | +# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? |
| 269 | + |
| 270 | +# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GraphNeuralNetworks.jl/stable/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. |
| 271 | + |
| 272 | + |
| 273 | + |
| 274 | +# ## Conclusion |
| 275 | +# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. |
0 commit comments