Commit 2a2cddd

Fix temporal graph classification literate
1 parent f2a7600 commit 2a2cddd

4 files changed: +209 -9 lines changed


GNNGraphs/src/temporalsnapshotsgnngraph.jl

Lines changed: 3 additions & 3 deletions
@@ -54,10 +54,10 @@ GNNGraph:
 ```
 """
 struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore}
-    num_nodes::Vector{Int}
-    num_edges::Vector{Int}
+    num_nodes::AbstractVector{Int}
+    num_edges::AbstractVector{Int}
     num_snapshots::Int
-    snapshots::Vector{G}
+    snapshots::AbstractVector{G}
     tgdata::D
 end
 

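As a reading aid for the `TemporalSnapshotsGNNGraph` hunk above, here is a minimal sketch of what relaxing a field from `Vector{Int}` to `AbstractVector{Int}` changes in plain Julia. The struct names `ConcreteCounts` and `AbstractCounts` are hypothetical and not part of GNNGraphs.jl. The commit does not state its motivation, so this only illustrates the language semantics.

```julia
# Hypothetical stand-ins for the old and new field declarations.
struct ConcreteCounts
    num_nodes::Vector{Int}
end

struct AbstractCounts
    num_nodes::AbstractVector{Int}
end

# The default constructor converts each argument to the declared field type.
# With a concrete `Vector{Int}` field, a range is materialized into a new Vector.
old_style = ConcreteCounts(1:3)

# With an abstract field type, the range already satisfies the declaration,
# so it is stored unchanged.
new_style = AbstractCounts(1:3)

println(typeof(old_style.num_nodes))
println(typeof(new_style.num_nodes))
```

The usual trade-off is that abstractly typed fields accept more container types but are harder for the compiler to specialize on.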
GraphNeuralNetworks/docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ makedocs(;
             ],
             "Temporal graph neural networks" =>[
                 "Node autoregression" => "tutorials/traffic_prediction.md",
-                "Temporal graph classification" => "tutorials/temporal_graph_classification_pluto.md"
+                "Temporal graph classification" => "tutorials/temporal_graph_classification.md"
             ],
         ],
 

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
```@meta
EditURL = "../../src_tutorials/introductory_tutorials/temporal_graph_classification.jl"
```

# Temporal Graph classification with GraphNeuralNetworks.jl

In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying.

We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU.

## Import

We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others.

````julia
using Flux
using GraphNeuralNetworks
using Statistics, Random
using LinearAlgebra
using MLDatasets: TemporalBrains
using CUDA # comment out if you don't have a CUDA GPU

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
Random.seed!(17); # for reproducibility
````

## Dataset: TemporalBrains
The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation).
Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions.

The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time.
For each snapshot, the feature of a node represents the average activation of the node during that snapshot.
Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+).
The network's edge weights are binarized, and the threshold is set to 0.6 by default.

````julia
brain_dataset = TemporalBrains()
````

````
dataset TemporalBrains:
  graphs => 1000-element Vector{MLDatasets.TemporalSnapshotsGraph}
````

After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs, which we need to convert to the `TemporalSnapshotsGNNGraph` format.
So we create a function called `data_loader` that performs this conversion and splits the data into a training set, used to train the model, and a test set, used to evaluate its performance.

````julia
function data_loader(brain_dataset)
    graphs = brain_dataset.graphs
    dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
    for i in 1:length(graphs)
        graph = graphs[i]
        dataset[i] = TemporalSnapshotsGNNGraph(GNNGraphs.mlgraph2gnngraph.(graph.snapshots))
        # Add graph and node features
        for t in 1:27
            s = dataset[i].snapshots[t]
            s.ndata.x = [I(102); s.ndata.x']
        end
        dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"]))
    end
    # Use the first 250 subjects: 200 (80%) for training and 50 (20%) for testing
    train_loader = dataset[1:200]
    test_loader = dataset[201:250]
    return train_loader, test_loader
end
````

````
data_loader (generic function with 1 method)
````

The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then wrapped in a `TemporalSnapshotsGNNGraph`.

The second part adds the graph and node features to the temporal graphs. For the node features, it stacks a one-hot encoding of each node's identity (here we directly use the identity matrix `I(102)`) on top of the node's mean activation during the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender.

The last part splits the dataset.
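
The node-feature construction described above can be sketched on a toy example. This is a minimal illustration that assumes the per-snapshot activations form a plain vector with one entry per node. The names `n`, `activation` and `features` are hypothetical, and 3 nodes stand in for the 102 brain regions.

```julia
using LinearAlgebra

n = 3                                 # toy number of nodes (the tutorial uses 102)
activation = Float32[0.1, 0.2, 0.3]   # one mean activation per node

# Stack the identity matrix (one-hot node identities) on top of the
# activations laid out as a 1 x n row, mirroring `[I(102); s.ndata.x']`.
features = [I(n); activation']

println(size(features))   # (n + 1) feature rows, one column per node
```

With 102 nodes the same construction yields 103 features per node, which matches the `nfeatures = 103` default used by the model constructor.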

## Model

We now implement a simple model that takes a `TemporalSnapshotsGNNGraph` as input.
It consists of a `GINConv` applied independently to each snapshot, a `GlobalPool` to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a `Dense` layer.

First, we adapt `GlobalPool` so that it accepts a `TemporalSnapshotsGNNGraph`.

````julia
function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
    h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
    sze = size(h[1])
    reshape(reduce(hcat, h), sze[1], length(h))
end
````

Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the forward pass.

````julia
struct GenderPredictionModel
    gin::GINConv
    mlp::Chain
    globalpool::GlobalPool
    dense::Dense
end

Flux.@layer GenderPredictionModel

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu)
    mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ))
    gin = GINConv(mlp, 0.5)
    globalpool = GlobalPool(mean)
    dense = Dense(nhidden => 2)
    return GenderPredictionModel(gin, mlp, globalpool, dense)
end

function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
    h = m.gin(g, g.ndata.x)
    h = m.globalpool(g, h)
    h = mean(h, dims=2)
    return m.dense(h)
end
````
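
To make the shapes in the forward pass concrete, the two pooling steps can be sketched with plain arrays. This is only an illustration under stated assumptions: random matrices stand in for the `GINConv` output, mean pooling is used throughout, and no GraphNeuralNetworks.jl function is called. All names and sizes are hypothetical.

```julia
using Statistics

nhidden, nnodes, nsnapshots = 4, 5, 3   # toy sizes

# Stand-in for the per-snapshot node embeddings produced by the graph convolution:
# one nhidden x nnodes matrix per snapshot.
node_embeddings = [rand(Float32, nhidden, nnodes) for _ in 1:nsnapshots]

# Pool over nodes: one nhidden x 1 column per snapshot.
snapshot_embeddings = [mean(h, dims = 2) for h in node_embeddings]

# Concatenate the columns into an nhidden x nsnapshots matrix.
H = reduce(hcat, snapshot_embeddings)

# Pool over time: a single nhidden x 1 embedding for the whole temporal graph.
graph_embedding = mean(H, dims = 2)

println(size(H), " ", size(graph_embedding))
```

The final `Dense(nhidden => 2)` layer then maps this single column to two logits, one per class.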

## Training

We train the model for 100 epochs using the Adam optimizer with a learning rate of 0.001. As the loss function we use `logitbinarycrossentropy`, which is commonly used for two-class classification; here the labels are given in one-hot format.
The accuracy expresses the number of correct classifications.
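
As a rough sketch of what these two quantities measure, the following plain-Julia code computes a sigmoid cross-entropy from logits and checks one prediction. It is a stand-in written for illustration, not Flux's implementation; the helper names `logsigmoid`, `logit_bce` and `predicted_class` are hypothetical, and the logits are made up.

```julia
using Statistics

# log(sigmoid(x)) written as -log1p(exp(-x)).
# This naive form can overflow for very negative x; it is fine for a toy example.
logsigmoid(x) = -log1p(exp(-x))

# Mean binary cross-entropy computed directly from logits.
logit_bce(logits, target) = mean(@. (1 - target) * logits - logsigmoid(logits))

# argmax plays the role of decoding a one-hot vector to a class index.
predicted_class(v) = argmax(v)

target = Float32[1, 0]        # one-hot target for the first class
logits = Float32[2.0, -1.0]   # hypothetical model output

loss = logit_bce(logits, target)
correct = predicted_class(logits) == predicted_class(target)
println((; loss, correct))
```

In the tutorial code the per-graph correctness is multiplied by 100 and averaged over the data set, so the reported accuracy is a percentage.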

````julia
lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y);

function eval_loss_accuracy(model, data_loader)
    error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader])
    acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader])
    return (loss = error, acc = acc)
end

function train(dataset)
    device = gpu_device()

    function report(epoch)
        train_loss, train_acc = eval_loss_accuracy(model, train_loader)
        test_loss, test_acc = eval_loss_accuracy(model, test_loader)
        println("Epoch: $epoch  $((; train_loss, train_acc))  $((; test_loss, test_acc))")
        return (train_loss, train_acc, test_loss, test_acc)
    end

    model = GenderPredictionModel() |> device

    opt = Flux.setup(Adam(1.0f-3), model)

    train_loader, test_loader = data_loader(dataset)
    train_loader = train_loader |> device
    test_loader = test_loader |> device

    report(0)
    for epoch in 1:100
        for g in train_loader
            grads = Flux.gradient(model) do model
                ŷ = model(g)
                lossfunction(vec(ŷ), g.tgdata.g)
            end
            Flux.update!(opt, model, grads[1])
        end
        if epoch % 10 == 0
            report(epoch)
        end
    end
    return model
end


train(brain_dataset);
````

````
Epoch: 0 (train_loss = 0.80321693f0, train_acc = 50.5) (test_loss = 0.79863846f0, test_acc = 60.0)
Epoch: 10 (train_loss = 0.61757874f0, train_acc = 63.5) (test_loss = 0.6142881f0, test_acc = 72.0)
Epoch: 20 (train_loss = 0.50907505f0, train_acc = 74.0) (test_loss = 0.646904f0, test_acc = 60.0)
Epoch: 30 (train_loss = 0.35090268f0, train_acc = 81.0) (test_loss = 0.65224814f0, test_acc = 60.0)
Epoch: 40 (train_loss = 0.13825743f0, train_acc = 97.0) (test_loss = 0.58508986f0, test_acc = 74.0)
Epoch: 50 (train_loss = 0.44244948f0, train_acc = 77.0) (test_loss = 1.5108807f0, test_acc = 62.0)
Epoch: 60 (train_loss = 0.033900682f0, train_acc = 99.5) (test_loss = 0.593368f0, test_acc = 78.0)
Epoch: 70 (train_loss = 0.04119176f0, train_acc = 99.5) (test_loss = 0.4229265f0, test_acc = 84.0)
Epoch: 80 (train_loss = 0.018655278f0, train_acc = 99.5) (test_loss = 0.5038431f0, test_acc = 88.0)
Epoch: 90 (train_loss = 0.0074938983f0, train_acc = 100.0) (test_loss = 0.5612772f0, test_acc = 88.0)
Epoch: 100 (train_loss = 0.021453373f0, train_acc = 99.5) (test_loss = 0.4984316f0, test_acc = 84.0)

````

## Conclusions

In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The test accuracy of the model is approximately 80%, but it can likely be improved by tuning the hyperparameters and training on more data.

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl

Lines changed: 9 additions & 5 deletions
@@ -16,6 +16,10 @@ using LinearAlgebra
 using MLDatasets: TemporalBrains
 using CUDA # comment out if you don't have a CUDA GPU
 
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
+Random.seed!(17); # for reproducibility
+
+
 # ## Dataset: TemporalBrains
 # The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation).
 # Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions.
@@ -36,15 +40,15 @@ function data_loader(brain_dataset)
     dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
     for i in 1:length(graphs)
         graph = graphs[i]
-        dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots))
-        # Add graph and node features
+        dataset[i] = TemporalSnapshotsGNNGraph(GNNGraphs.mlgraph2gnngraph.(graph.snapshots))
+        ## Add graph and node features
         for t in 1:27
             s = dataset[i].snapshots[t]
             s.ndata.x = [I(102); s.ndata.x']
         end
         dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"]))
     end
-    # Split the dataset into a 80% training set and a 20% test set
+    ## Split the dataset into a 80% training set and a 20% test set
     train_loader = dataset[1:200]
     test_loader = dataset[201:250]
     return train_loader, test_loader
@@ -143,8 +147,8 @@ function train(dataset)
 end
 
 
-train(brain_dataset)
+train(brain_dataset);
 
 ## Conclusions
 #
-# In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data.
+# In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 80%, but can be improved by fine-tuning the parameters and training on more data.
