Skip to content

Commit 5250c4d

Browse files
authored
Add example for gender classification on TemporalBrains dataset (#397)
* Add example gender classification on fMRI data * Improved description * Fix typo * Done
1 parent 50ebac3 commit 5250c4d

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Example of graph classification when graphs are temporal and modeled as `TemporalSnapshotsGNNGraphs'.
2+
# In this code, we train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals.
3+
# The dataset used is the TemporalBrains dataset from the MLDataset.jl package, and the accuracy achieved with the model reaches 65-70% (it can be improved by fine-tuning the parameters of the model).
4+
# Author: Aurora Rossi
5+
6+
# Load packages
7+
using Flux
8+
using Flux.Losses: mae
9+
using GraphNeuralNetworks
10+
using CUDA
11+
using Statistics, Random
12+
using LinearAlgebra
13+
using MLDatasets
14+
CUDA.allowscalar(false)
15+
16+
# Load data
17+
MLdataset = TemporalBrains()
18+
graphs = MLdataset.graphs
19+
20+
# Function to transform the graphs from the MLDatasets format to the TemporalSnapshotsGNNGraph format
21+
# and split the dataset into a training and a test set
22+
function data_loader(graphs)
23+
dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
24+
for i in 1:length(graphs)
25+
gr = graphs[i]
26+
dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(gr.snapshots))
27+
for t in 1:27
28+
dataset[i].snapshots[t].ndata.x = reduce(
29+
vcat, [I(102), dataset[i].snapshots[t].ndata.x'])
30+
end
31+
dataset[i].tgdata.g = Float32.(Array(Flux.onehot(gr.graph_data.g, ["F", "M"])))
32+
end
33+
# Split the dataset into a 80% training set and a 20% test set
34+
train_loader = dataset[1:800]
35+
test_loader = dataset[801:1000]
36+
return train_loader, test_loader
37+
end
38+
39+
# Arguments for the train function
40+
Base.@kwdef mutable struct Args
41+
η = 1.0f-3 # learning rate
42+
epochs = 200 # number of epochs
43+
seed = -5 # set seed > 0 for reproducibility
44+
usecuda = true # if true use cuda (if available)
45+
nhidden = 128 # dimension of hidden features
46+
infotime = 10 # report every `infotime` epochs
47+
end
48+
49+
# Adapt GlobalPool to work with TemporalSnapshotsGNNGraph
50+
function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
51+
h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)]
52+
sze = size(h[1])
53+
reshape(reduce(hcat, h), sze[1], length(h))
54+
end
55+
56+
# Define the model
57+
struct GenderPredictionModel
58+
gin::GINConv
59+
mlp::Chain
60+
globalpool::GlobalPool
61+
f::Function
62+
dense::Dense
63+
end
64+
65+
Flux.@functor GenderPredictionModel
66+
67+
function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
68+
mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
69+
gin = GINConv(mlp, 0.5)
70+
globalpool = GlobalPool(mean)
71+
f = x -> mean(x, dims = 2)
72+
dense = Dense(nhidden, 2)
73+
GenderPredictionModel(gin, mlp, globalpool, f, dense)
74+
end
75+
76+
function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph)
77+
h = m.gin(g, g.ndata.x)
78+
h = m.globalpool(g, h)
79+
h = m.f(h)
80+
m.dense(h)
81+
end
82+
83+
# Train the model
84+
85+
function train(graphs; kws...)
86+
args = Args(; kws...)
87+
args.seed > 0 && Random.seed!(args.seed)
88+
89+
if args.usecuda && CUDA.functional()
90+
my_device = gpu
91+
args.seed > 0 && CUDA.seed!(args.seed)
92+
@info "Training on GPU"
93+
else
94+
my_device = cpu
95+
@info "Training on CPU"
96+
end
97+
98+
lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y) |> my_device
99+
100+
function eval_loss_accuracy(model, data_loader)
101+
error = mean([lossfunction(model(g), gpu(g.tgdata.g)) for g in data_loader])
102+
acc = mean([round(
103+
100 *
104+
mean(Flux.onecold(model(g)) .== Flux.onecold(gpu(g.tgdata.g)));
105+
digits = 2) for g in data_loader])
106+
return (loss = error, acc = acc)
107+
end
108+
109+
function report(epoch)
110+
train_loss, train_acc = eval_loss_accuracy(model, train_loader)
111+
test_loss, test_acc = eval_loss_accuracy(model, test_loader)
112+
println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
113+
return (train_loss, train_acc, test_loss, test_acc)
114+
end
115+
116+
model = GenderPredictionModel() |> my_device
117+
118+
opt = Flux.setup(Adam(args.η), model)
119+
120+
train_loader, test_loader = data_loader(graphs) # it takes a while to load the data
121+
122+
train_loader = train_loader |> my_device
123+
test_loader = test_loader |> my_device
124+
125+
report(0)
126+
for epoch in 1:(args.epochs)
127+
for g in train_loader
128+
grads = Flux.gradient(model) do model
129+
= model(g)
130+
lossfunction(vec(ŷ), g.tgdata.g)
131+
end
132+
Flux.update!(opt, model, grads[1])
133+
end
134+
if args.infotime > 0 && epoch % args.infotime == 0
135+
report(epoch)
136+
end
137+
end
138+
return model
139+
end
140+
141+
model = train(graphs)

0 commit comments

Comments
 (0)