Skip to content

Commit 57141aa

Browse files
committed
Fixes
1 parent a7f5252 commit 57141aa

File tree

2 files changed

+14
-29
lines changed

2 files changed

+14
-29
lines changed

GraphNeuralNetworks/docs/src/tutorials/temporal_graph_classification.md

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
```@meta
2-
EditURL = "../../src_tutorials/introductory_tutorials/temporal_graph_classification.jl"
3-
```
4-
51
# Temporal Graph classification with GraphNeuralNetworks.jl
62

73
In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying.
@@ -86,8 +82,7 @@ First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`
8682
````julia
8783
function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
8884
h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
89-
sze = size(h[1])
90-
reshape(reduce(hcat, h), sze[1], length(h))
85+
return mean(h)
9186
end
9287
````
9388

@@ -114,7 +109,6 @@ end
114109
function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
115110
h = m.gin(g, g.ndata.x)
116111
h = m.globalpool(g, h)
117-
h = mean(h, dims=2)
118112
return m.dense(h)
119113
end
120114
````
@@ -160,35 +154,28 @@ function train(dataset)
160154
end
161155
Flux.update!(opt, model, grads[1])
162156
end
163-
if epoch % 10 == 0
157+
if epoch % 20 == 0
164158
report(epoch)
165159
end
166160
end
167161
return model
168162
end
169163

170-
171164
train(brain_dataset);
172-
173-
# Conclusions
174165
````
175166

176167
````
177168
Epoch: 0 (train_loss = 0.80321693f0, train_acc = 50.5) (test_loss = 0.79863846f0, test_acc = 60.0)
178-
Epoch: 10 (train_loss = 0.61757874f0, train_acc = 63.5) (test_loss = 0.6142881f0, test_acc = 72.0)
179-
Epoch: 20 (train_loss = 0.50907505f0, train_acc = 74.0) (test_loss = 0.646904f0, test_acc = 60.0)
180-
Epoch: 30 (train_loss = 0.35090268f0, train_acc = 81.0) (test_loss = 0.65224814f0, test_acc = 60.0)
181-
Epoch: 40 (train_loss = 0.13825743f0, train_acc = 97.0) (test_loss = 0.58508986f0, test_acc = 74.0)
182-
Epoch: 50 (train_loss = 0.44244948f0, train_acc = 77.0) (test_loss = 1.5108807f0, test_acc = 62.0)
183-
Epoch: 60 (train_loss = 0.033900682f0, train_acc = 99.5) (test_loss = 0.593368f0, test_acc = 78.0)
184-
Epoch: 70 (train_loss = 0.04119176f0, train_acc = 99.5) (test_loss = 0.4229265f0, test_acc = 84.0)
185-
Epoch: 80 (train_loss = 0.018655278f0, train_acc = 99.5) (test_loss = 0.5038431f0, test_acc = 88.0)
186-
Epoch: 90 (train_loss = 0.0074938983f0, train_acc = 100.0) (test_loss = 0.5612772f0, test_acc = 88.0)
187-
Epoch: 100 (train_loss = 0.021453373f0, train_acc = 99.5) (test_loss = 0.4984316f0, test_acc = 84.0)
169+
Epoch: 20 (train_loss = 0.5073769f0, train_acc = 74.5) (test_loss = 0.64655066f0, test_acc = 60.0)
170+
Epoch: 40 (train_loss = 0.13417317f0, train_acc = 96.5) (test_loss = 0.5689327f0, test_acc = 74.0)
171+
Epoch: 60 (train_loss = 0.01875147f0, train_acc = 100.0) (test_loss = 0.45651233f0, test_acc = 82.0)
172+
Epoch: 80 (train_loss = 0.12695672f0, train_acc = 95.0) (test_loss = 0.65159386f0, test_acc = 82.0)
173+
Epoch: 100 (train_loss = 0.036399372f0, train_acc = 99.0) (test_loss = 0.6491585f0, test_acc = 86.0)
188174
189175
````
190176

191-
In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 80%, but can be improved by fine-tuning the parameters and training on more data.
177+
# Conclusions
178+
In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 85%, but can be improved by fine-tuning the parameters and training on more data.
192179

193180
---
194181

GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ end
6969

7070
function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
7171
h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
72-
sze = size(h[1])
73-
reshape(reduce(hcat, h), sze[1], length(h))
72+
return mean(h)
7473
end
7574

7675
# Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the foward pass.
@@ -95,7 +94,6 @@ end
9594
function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
9695
h = m.gin(g, g.ndata.x)
9796
h = m.globalpool(g, h)
98-
h = mean(h, dims=2)
9997
return m.dense(h)
10098
end
10199

@@ -139,16 +137,16 @@ function train(dataset)
139137
end
140138
Flux.update!(opt, model, grads[1])
141139
end
142-
if epoch % 10 == 0
140+
if epoch % 20 == 0
143141
report(epoch)
144142
end
145143
end
146144
return model
147145
end
148146

149-
150147
train(brain_dataset);
151148

152-
## Conclusions
149+
153150
#
154-
# In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 80%, but can be improved by fine-tuning the parameters and training on more data.
151+
# # Conclusions
152+
# In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 85%, but can be improved by fine-tuning the parameters and training on more data.

0 commit comments

Comments
 (0)