Commit f05af33
Add TemporalSnapshotsGNNgraph classification tutorial (#408)
* Create file
* Save intro
* First draft
* Improvements
* Almost ready
* Last comments
* Add front matter
* Add brain gif
* Fix frontmatter
* Better gif
* Add tutorial md
* Some improvements
* Fix function signature
* Add cuDNN
* Revert "Add cuDNN" (reverts commit 726b6d8)
* New version
* Update new version supporting CUDA
1 parent 0c641a2 commit f05af33

3 files changed: 1904 additions, 0 deletions

Lines changed: 211 additions & 0 deletions
```@raw html
<style>
table {
    display: table !important;
    margin: 2rem auto !important;
    border-top: 2pt solid rgba(0,0,0,0.2);
    border-bottom: 2pt solid rgba(0,0,0,0.2);
}

pre, div {
    margin-top: 1.4rem !important;
    margin-bottom: 1.4rem !important;
}

.code-output {
    padding: 0.7rem 0.5rem !important;
}

.admonition-body {
    padding: 0em 1.25em !important;
}
</style>

<!-- PlutoStaticHTML.Begin -->
<!--
# This information is used for caching.
[PlutoStaticHTML.State]
input_sha = "d90c787ba47953887e638f9c1a1d0446ad07c5f4b8a640ef5b0c846f0ead6598"
julia_version = "1.9.1"
-->

<div class="markdown"><p>In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying.</p><p>We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU.</p></div>

```
## Import
```@raw html
<div class="markdown">
<p>We start by importing the necessary libraries. We use <code>GraphNeuralNetworks.jl</code>, <code>Flux.jl</code> and <code>MLDatasets.jl</code>, among others.</p></div>

<pre class='language-julia'><code class='language-julia'>begin
    using Flux
    using GraphNeuralNetworks
    using Statistics, Random
    using LinearAlgebra
    using MLDatasets: TemporalBrains
    using CUDA
    using cuDNN
end</code></pre>

```
## Dataset: TemporalBrains
```@raw html
<div class="markdown">
<p>The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects, obtained from resting-state functional MRI data from the <a href="https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation">Human Connectome Project (HCP)</a>. Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions.</p><p>The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time. For each snapshot, the feature of a node is its average activation during that snapshot. Each temporal graph has a label representing gender ('M' for male and 'F' for female) and an age group (22-25, 26-30, 31-35, and 36+). The network's edge weights are binarized, with the threshold set to 0.6 by default.</p></div>

<pre class='language-julia'><code class='language-julia'>brain_dataset = TemporalBrains()</code></pre>
<pre class="code-output documenter-example-output" id="var-brain_dataset">dataset TemporalBrains:
  graphs =&gt; 1000-element Vector{MLDatasets.TemporalSnapshotsGraph}</pre>

<div class="markdown"><p>After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs, which we need to convert to the <code>TemporalSnapshotsGNNGraph</code> format. To do so, we create a function called <code>data_loader</code> that performs this conversion and splits the dataset into a training set, used to train the model, and a test set, used to evaluate its performance.</p></div>

<pre class='language-julia'><code class='language-julia'>function data_loader(brain_dataset)
    graphs = brain_dataset.graphs
    dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
    for i in 1:length(graphs)
        graph = graphs[i]
        dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots))
        # Add graph and node features
        for t in 1:27
            s = dataset[i].snapshots[t]
            s.ndata.x = [I(102); s.ndata.x']
        end
        dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"]))
    end
    # Keep the first 250 graphs: 200 (80%) for training and 50 (20%) for testing
    train_loader = dataset[1:200]
    test_loader = dataset[201:250]
    return train_loader, test_loader
end;</code></pre>

<div class="markdown"><p>The first part of the <code>data_loader</code> function calls the <code>mlgraph2gnngraph</code> function on each snapshot, converting it to a <code>GNNGraph</code>; the resulting vector of <code>GNNGraph</code>s is then wrapped in a <code>TemporalSnapshotsGNNGraph</code>.</p><p>The second part adds the node and graph features. For each snapshot, it builds the node features by stacking a one-hot encoding of each node's id (the identity matrix) on top of the node's mean activation during that snapshot (contained in the vector <code>dataset[i].snapshots[t].ndata.x</code>, where <code>i</code> indexes the subject and <code>t</code> the snapshot). As the graph feature, it adds the one-hot encoding of gender.</p><p>The last part splits the dataset.</p></div>

```
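To make the feature construction concrete, here is a small sketch (not part of the original notebook) of what `[I(102); s.ndata.x']` produces for a single snapshot, using random numbers in place of the real fMRI activations: a 103×102 matrix whose first 102 rows one-hot encode the node ids and whose last row holds the mean activations.

```julia
using LinearAlgebra

# Illustrative stand-in for one snapshot's node data (an assumption, not real data).
x = rand(Float32, 102)      # mean activation of the 102 brain regions
features = [I(102); x']     # 103×102: one-hot node ids stacked on the activations

size(features)              # (103, 102)
```

This is why the model below uses `nfeatures = 103`: each node's feature vector has 102 one-hot entries plus one activation value.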
## Model
```@raw html
<div class="markdown">
<p>We now implement a simple model that takes a <code>TemporalSnapshotsGNNGraph</code> as input. It consists of a <code>GINConv</code> applied independently to each snapshot, a <code>GlobalPool</code> to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a <code>Dense</code> layer.</p><p>First, we start by adapting the <code>GlobalPool</code> to the <code>TemporalSnapshotsGNNGraphs</code>.</p></div>

<pre class='language-julia'><code class='language-julia'>function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
    h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
    sze = size(h[1])
    reshape(reduce(hcat, h), sze[1], length(h))
end</code></pre>

<div class="markdown"><p>Then we implement the constructor of the model, which we call <code>GenderPredictionModel</code>, and the forward pass.</p></div>

<pre class='language-julia'><code class='language-julia'>begin
    struct GenderPredictionModel
        gin::GINConv
        mlp::Chain
        globalpool::GlobalPool
        f::Function
        dense::Dense
    end

    Flux.@functor GenderPredictionModel

    function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
        mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
        gin = GINConv(mlp, 0.5)
        globalpool = GlobalPool(mean)
        f = x -&gt; mean(x, dims = 2)
        dense = Dense(nhidden, 2)
        GenderPredictionModel(gin, mlp, globalpool, f, dense)
    end

    function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
        h = m.gin(g, g.ndata.x)
        h = m.globalpool(g, h)
        h = m.f(h)
        m.dense(h)
    end
end</code></pre>

```
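As a sanity check on the pooling shapes (a sketch with dummy data, not from the original notebook): after the snapshot-wise `GlobalPool`, the temporal graph is represented by one `nhidden`-dimensional column per snapshot, and `f` averages these columns over the time dimension.

```julia
using Statistics

# Dummy snapshot embeddings standing in for the GlobalPool output (an assumption):
# one nhidden-dimensional column per snapshot.
nhidden, nsnapshots = 128, 27
h = rand(Float32, nhidden, nsnapshots)
f = x -> mean(x, dims = 2)   # the temporal pooling used in GenderPredictionModel

size(f(h))                   # (128, 1), ready for the final Dense(128 => 2) layer
```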
## Training
```@raw html
<div class="markdown">
<p>We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use <code>logitbinarycrossentropy</code> as the loss function, which is typically used for two-class classification where the labels are given in a one-hot format. The accuracy expresses the percentage of correct classifications.</p></div>

<pre class='language-julia'><code class='language-julia'>lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y);</code></pre>

<pre class='language-julia'><code class='language-julia'>function eval_loss_accuracy(model, data_loader)
    error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader])
    acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader])
    return (loss = error, acc = acc)
end;</code></pre>

<pre class='language-julia'><code class='language-julia'>function train(dataset; usecuda::Bool, kws...)

    if usecuda && CUDA.functional() # check if a GPU is available
        my_device = gpu
        @info "Training on GPU"
    else
        my_device = cpu
        @info "Training on CPU"
    end

    function report(epoch)
        train_loss, train_acc = eval_loss_accuracy(model, train_loader)
        test_loss, test_acc = eval_loss_accuracy(model, test_loader)
        println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
        return (train_loss, train_acc, test_loss, test_acc)
    end

    model = GenderPredictionModel() |&gt; my_device

    opt = Flux.setup(Adam(1.0f-3), model)

    train_loader, test_loader = data_loader(dataset)
    train_loader = train_loader |&gt; my_device
    test_loader = test_loader |&gt; my_device

    report(0)
    for epoch in 1:100
        for g in train_loader
            grads = Flux.gradient(model) do model
                ŷ = model(g)
                lossfunction(vec(ŷ), g.tgdata.g)
            end
            Flux.update!(opt, model, grads[1])
        end
        if epoch % 10 == 0
            report(epoch)
        end
    end
    return model
end;
</code></pre>

<pre class='language-julia'><code class='language-julia'>train(brain_dataset; usecuda = true)</code></pre>
<pre class="code-output documenter-example-output" id="var-hash305203">GenderPredictionModel(GINConv(Chain(Dense(103 =&gt; 128, relu), Dense(128 =&gt; 128, relu)), 0.5), Chain(Dense(103 =&gt; 128, relu), Dense(128 =&gt; 128, relu)), GlobalPool{typeof(mean)}(Statistics.mean), var"#4#5"(), Dense(128 =&gt; 2))</pre>

<div class="markdown"><p>We set up the training to run on the GPU because training takes a long time, especially on the CPU.</p></div>

```
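To illustrate what the loss and accuracy actually compute, here is a minimal numerical sketch with hand-made logits (not part of the original notebook), assuming the standard definitions behind Flux's `logitbinarycrossentropy` and `onecold`:

```julia
using Statistics

# Minimal re-implementations, assumed equivalent to the Flux functions used above.
σ(x) = 1 / (1 + exp(-x))
bce(ŷ, y) = mean(@. -y * log(σ(ŷ)) - (1 - y) * log(1 - σ(ŷ)))  # logitbinarycrossentropy
onecold(m) = map(argmax, eachcol(m))                            # column-wise argmax

ŷ = Float32[2.3 -0.4; -1.1 0.9]   # logits for two graphs (2 classes × 2 samples)
y = Float32[1 0; 0 1]              # one-hot gender labels ("F", "M")

loss = bce(ŷ, y)                       # lower when the logits favor the true class
acc = mean(onecold(ŷ) .== onecold(y))  # 1.0: both graphs classified correctly
```

Note that `onecold` inverts the one-hot encoding, so comparing `onecold(ŷ)` with `onecold(y)` counts how often the largest logit lands on the true class.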
## Conclusions
```@raw html
<div class="markdown">
<p>In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but it can be improved by fine-tuning the parameters and training on more data.</p></div>

<!-- PlutoStaticHTML.End -->
```