|
1 |
| -```@meta |
2 |
| -EditURL = "../../src_tutorials/introductory_tutorials/temporal_graph_classification.jl" |
3 |
| -``` |
4 |
| - |
5 | 1 | # Temporal Graph classification with GraphNeuralNetworks.jl
|
6 | 2 |
|
7 | 3 | In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying.
|
@@ -86,8 +82,7 @@ First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`
|
86 | 82 | ````julia
|
87 | 83 | function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
|
88 | 84 | h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
|
89 |
| - sze = size(h[1]) |
90 |
| - reshape(reduce(hcat, h), sze[1], length(h)) |
| 85 | + return mean(h) |
91 | 86 | end
|
92 | 87 | ````
|
93 | 88 |
|
|
114 | 109 | function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
|
115 | 110 | h = m.gin(g, g.ndata.x)
|
116 | 111 | h = m.globalpool(g, h)
|
117 |
| - h = mean(h, dims=2) |
118 | 112 | return m.dense(h)
|
119 | 113 | end
|
120 | 114 | ````
|
@@ -160,35 +154,28 @@ function train(dataset)
|
160 | 154 | end
|
161 | 155 | Flux.update!(opt, model, grads[1])
|
162 | 156 | end
|
163 |
| - if epoch % 10 == 0 |
| 157 | + if epoch % 20 == 0 |
164 | 158 | report(epoch)
|
165 | 159 | end
|
166 | 160 | end
|
167 | 161 | return model
|
168 | 162 | end
|
169 | 163 |
|
170 |
| - |
171 | 164 | train(brain_dataset);
|
172 |
| - |
173 |
| -# Conclusions |
174 | 165 | ````
|
175 | 166 |
|
176 | 167 | ````
|
177 | 168 | Epoch: 0 (train_loss = 0.80321693f0, train_acc = 50.5) (test_loss = 0.79863846f0, test_acc = 60.0)
|
178 |
| -Epoch: 10 (train_loss = 0.61757874f0, train_acc = 63.5) (test_loss = 0.6142881f0, test_acc = 72.0) |
179 |
| -Epoch: 20 (train_loss = 0.50907505f0, train_acc = 74.0) (test_loss = 0.646904f0, test_acc = 60.0) |
180 |
| -Epoch: 30 (train_loss = 0.35090268f0, train_acc = 81.0) (test_loss = 0.65224814f0, test_acc = 60.0) |
181 |
| -Epoch: 40 (train_loss = 0.13825743f0, train_acc = 97.0) (test_loss = 0.58508986f0, test_acc = 74.0) |
182 |
| -Epoch: 50 (train_loss = 0.44244948f0, train_acc = 77.0) (test_loss = 1.5108807f0, test_acc = 62.0) |
183 |
| -Epoch: 60 (train_loss = 0.033900682f0, train_acc = 99.5) (test_loss = 0.593368f0, test_acc = 78.0) |
184 |
| -Epoch: 70 (train_loss = 0.04119176f0, train_acc = 99.5) (test_loss = 0.4229265f0, test_acc = 84.0) |
185 |
| -Epoch: 80 (train_loss = 0.018655278f0, train_acc = 99.5) (test_loss = 0.5038431f0, test_acc = 88.0) |
186 |
| -Epoch: 90 (train_loss = 0.0074938983f0, train_acc = 100.0) (test_loss = 0.5612772f0, test_acc = 88.0) |
187 |
| -Epoch: 100 (train_loss = 0.021453373f0, train_acc = 99.5) (test_loss = 0.4984316f0, test_acc = 84.0) |
| 169 | +Epoch: 20 (train_loss = 0.5073769f0, train_acc = 74.5) (test_loss = 0.64655066f0, test_acc = 60.0) |
| 170 | +Epoch: 40 (train_loss = 0.13417317f0, train_acc = 96.5) (test_loss = 0.5689327f0, test_acc = 74.0) |
| 171 | +Epoch: 60 (train_loss = 0.01875147f0, train_acc = 100.0) (test_loss = 0.45651233f0, test_acc = 82.0) |
| 172 | +Epoch: 80 (train_loss = 0.12695672f0, train_acc = 95.0) (test_loss = 0.65159386f0, test_acc = 82.0) |
| 173 | +Epoch: 100 (train_loss = 0.036399372f0, train_acc = 99.0) (test_loss = 0.6491585f0, test_acc = 86.0) |
188 | 174 |
|
189 | 175 | ````
|
190 | 176 |
|
191 |
| -In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 80%, but can be improved by fine-tuning the parameters and training on more data. |
| 177 | +# Conclusions |
| 178 | +In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 85%, but can be improved by fine-tuning the parameters and training on more data. |
192 | 179 |
|
193 | 180 | ---
|
194 | 181 |
|
|
0 commit comments