# Graph Classification with Graph Neural Networks

*This tutorial is a Julia adaptation of the Pytorch Geometric tutorial that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).*

In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**.
Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties.
Here, we want to embed entire graphs, and we want to embed those graphs in such a way that they are linearly separable for the task at hand.

The most common task for graph classification is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not.

TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

````julia
using Lux, GNNLux
using MLDatasets, MLUtils
using LinearAlgebra, Random, Statistics
using Zygote, Optimisers, OneHotArrays

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"  # don't ask for dataset download confirmation
rng = Random.seed!(42); # for reproducibility

dataset = TUDataset("MUTAG")
````

````
dataset TUDataset:
  name        =>    MUTAG
  metadata    =>    Dict{String, Any} with 1 entry
  graphs      =>    188-element Vector{MLDatasets.Graph}
  graph_data  =>    (targets = "188-element Vector{Int64}",)
  num_nodes   =>    3371
  num_edges   =>    7442
  num_graphs  =>    188
````

````julia
dataset.graph_data.targets |> union
````

````
2-element Vector{Int64}:
  1
 -1
````

````julia
g1, y1 = dataset[1] # get the first graph and target
````

````
(graphs = Graph(17, 38), targets = 1)
````

````julia
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union
````

````
7-element Vector{Int64}:
 0
 1
 2
 3
 4
 5
 6
````

````julia
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union
````

````
4-element Vector{Int64}:
 0
 1
 2
 3
````

This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**.

By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**.
It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes).
However, for the sake of simplicity, we will not make use of edge labels.

We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s, and we also one-hot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict):

````julia
graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                    ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                    edata = nothing)
            for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])
````

````
2×188 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  1  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  1  1  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  1  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  1  ⋅  1  1  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  ⋅  1  1  1  1  1  ⋅  1  ⋅  ⋅  1  1  ⋅  1  1  1  1  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  1  1  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  1  1  ⋅  1
 1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  1  1  1  1  ⋅  1  1  ⋅  1  ⋅  1  1  1  1  1  1  1  1  1  1  1  1  1  1  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  1  1  1  1  1  1  1  1  1  1  1  1  ⋅  1  1  1  1  1  1  ⋅  1  1  ⋅  ⋅  1  1  1  ⋅  1  1  ⋅  1  1  ⋅  ⋅  ⋅  1  1  1  1  1  ⋅  1  1  1  ⋅  ⋅  1  1  1  1  1  1  1  1  ⋅  1  ⋅  1  1  1  1  1  1  1  1  1  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  1  1  ⋅  ⋅  1  1  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  1  1  ⋅  1  1  ⋅  1  1  1  ⋅  ⋅  ⋅  1  1  1  ⋅  1  1  1  1  1  1  1  ⋅  1  1  1  1  1  1  ⋅  1  1  1  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  1  ⋅
````

We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:

````julia
train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs


train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)
````

````
2-element DataLoader(::Tuple{Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (32-element Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)
````

Here, we opt for a `batchsize` of 32, leading to 5 (randomly shuffled) mini-batches, containing all $4 \cdot 32+22 = 150$ graphs.
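
As a quick sanity check, we can count how many mini-batches each loader yields (the snippet below only assumes the `train_loader` and `test_loader` defined above and the 150/38 split):

````julia
length(train_loader)  # 5 mini-batches covering the 150 training graphs
length(test_loader)   # 2 mini-batches covering the 38 test graphs
````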

## Mini-batching of graphs

Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization.
In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension.
The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`.

However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.
Therefore, GNNLux.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension).

This procedure has some crucial advantages over other batching procedures:

1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.

GNNLux.jl can **batch multiple graphs into a single giant graph**:

````julia
vec_gs, _ = first(train_loader)
````

````
(GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(11, 22) with x: 7×11 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(28, 66) with x: 7×28 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(27, 66) with x: 7×27 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(16, 36) with x: 7×16 data], Bool[1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0; 0 1 1 1 1 1 1 1 0 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1])
````

````julia
MLUtils.batch(vec_gs)
````

````
GNNGraph:
  num_nodes: 570
  num_edges: 1254
  num_graphs: 32
  ndata:
    x = 7×570 Matrix{Float32}
````

Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:

```math
\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ]
```
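
To make this concrete, here is a small hand-built sketch of batching two toy graphs (the edge lists are made up for illustration; we only assume the `GNNGraph` source/target constructor and the `graph_indicator` field described above):

````julia
# Two tiny toy graphs given as (source, target) edge lists -- illustrative only
g1 = GNNGraph([1, 2], [2, 1])          # 2 nodes, 2 directed edges
g2 = GNNGraph([1, 2, 3], [2, 3, 1])    # 3 nodes, 3 directed edges

gb = MLUtils.batch([g1, g2])
gb.num_nodes         # 5: the nodes of g2 now occupy indices 3:5
gb.graph_indicator   # [1, 1, 2, 2, 2]: maps each node to its graph
````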

## Training a Graph Neural Network (GNN)

Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

There exist multiple **readout layers** in the literature, but the most common one is to simply take the average of node embeddings:

```math
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
```

GNNLux.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`.
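
Conceptually, this mean readout can be written in a few lines. The helper below is only an illustrative sketch of what such a pooling step computes (the name `mean_readout` is made up, and this is not how `GlobalPool` is implemented internally):

````julia
# Illustrative sketch: average the node embeddings of each graph in the batch
function mean_readout(x::AbstractMatrix, graph_indicator::AbstractVector{<:Integer},
                      num_graphs::Integer)
    out = zeros(eltype(x), size(x, 1), num_graphs)   # [hidden_channels, num_graphs]
    counts = zeros(Int, num_graphs)
    for (node, gid) in enumerate(graph_indicator)
        out[:, gid] .+= view(x, :, node)   # accumulate node embeddings per graph
        counts[gid] += 1                   # count nodes per graph
    end
    return out ./ reshape(counts, 1, :)    # divide each column by its node count
end
````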

The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

````julia
function create_model(nin, nh, nout)
    GNNChain(GCNConv(nin => nh, relu),
             GCNConv(nh => nh, relu),
             GCNConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;

nin = 7
nh = 64
nout = 2
model = create_model(nin, nh, nout)

ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);
````

````
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18

````

Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.

Let's train our network for a few epochs to see how well it performs on the training as well as test set:

````julia
function custom_loss(model, ps, st, tuple)
    g, x, y = tuple
    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
    st = Lux.trainmode(st)
    ŷ, st = model(g, x, ps, st)
    # Lux.Training expects the loss function to return (loss, updated_state, stats)
    return logitcrossentropy(ŷ, y), (; layers = st), 0
end

function eval_loss_accuracy(model, ps, st, data_loader)
    loss = 0.0
    acc = 0.0
    ntot = 0
    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
    for (g, y) in data_loader
        g = MLUtils.batch(g)
        n = length(y)
        ŷ, _ = model(g, g.ndata.x, ps, st)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end

function train_model!(model, ps, st; epochs = 500, infotime = 100)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2))

    function report(epoch)
        train = eval_loss_accuracy(model, ps, st, train_loader)
        st = Lux.testmode(st)
        test = eval_loss_accuracy(model, ps, st, test_loader)
        st = Lux.trainmode(st)
        @info (; epoch, train, test)
    end
    report(0)
    for iter in 1:epochs
        for (g, y) in train_loader
            g = MLUtils.batch(g)
            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.ndata.x, y), train_state)
        end

        iter % infotime == 0 && report(iter)
    end
    return model, ps, st
end

model, ps, st = train_model!(model, ps, st);
````

````
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils ~/.julia/packages/LuxLib/ru5RQ/src/utils.jl:314
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
[ Info: (epoch = 0, train = (loss = 0.6934, acc = 51.67), test = (loss = 0.6902, acc = 50.0))
[ Info: (epoch = 100, train = (loss = 0.3979, acc = 81.33), test = (loss = 0.5769, acc = 69.74))
[ Info: (epoch = 200, train = (loss = 0.3904, acc = 84.0), test = (loss = 0.6402, acc = 65.79))
[ Info: (epoch = 300, train = (loss = 0.3813, acc = 85.33), test = (loss = 0.6331, acc = 69.74))
[ Info: (epoch = 400, train = (loss = 0.3682, acc = 85.0), test = (loss = 0.7273, acc = 69.74))
[ Info: (epoch = 500, train = (loss = 0.3561, acc = 86.67), test = (loss = 0.6825, acc = 73.68))

````

As one can see, our model reaches around **74% test accuracy**.
The fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and they usually disappear once GNNs are applied to larger datasets.

## (Optional) Exercise

Can we do better than this?
As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**.
An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information:

```math
\mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)}
```

This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.

As an exercise, you are invited to modify the model above so that it makes use of `GraphConv` rather than `GCNConv`; a possible sketch is given below.
This should bring you close to **82% test accuracy**.
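
One possible sketch of the modified architecture simply swaps `GCNConv` for `GraphConv` in the model defined earlier (the function name is made up, and hyperparameters as well as the final accuracy depend on the training run):

````julia
# Exercise sketch: same model as before, but with GraphConv layers (no neighborhood normalization)
function create_graphconv_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end
````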

## Conclusion

In this chapter, you have learned how to apply GNNs to the task of graph classification.
You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
