
Commit acd68ad

Creating File tutorial for running GNNLux temporal Graph Layer
1 parent 58fcd7d commit acd68ad

File tree

3 files changed: +262 −18 lines


GNNLux/docs/src_tutorials/graph_classification.jl

Lines changed: 21 additions & 16 deletions
@@ -13,11 +13,22 @@
 # The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
 # Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:
 
-using Lux, GNNLux
-using MLDatasets, MLUtils
+using Lux
+using GNNLux
+using GNNlib # provides GNNlib.global_pool, used below
+using MLDatasets
+using MLUtils
 using LinearAlgebra, Random, Statistics
 using Zygote, Optimisers, OneHotArrays
 
+struct GlobalPool{F} <: GNNLayer
+    aggr::F
+end
+
+(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
+
+## Convenience method for graph-level embeddings; ps and st are passed explicitly,
+## since they are not otherwise in scope here
+(l::GlobalPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)[1])
+
 ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
 rng = Random.seed!(42); # for reproducibility
 
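To sanity-check the new `GlobalPool` methods, here is a minimal sketch on a synthetic graph; `rand_graph` and the sizes are stand-ins for a MUTAG graph, not part of the commit:

using GNNGraphs: rand_graph
g = rand_graph(rng, 10, 30)      # 10 nodes, 30 edges
x = rand(Float32, 7, 10)         # 7 features per node, as in MUTAG
pool = GlobalPool(mean)
emb, _ = pool(g, x, (;), (;))    # 7×1 graph-level embedding; ps and st are unused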

@@ -107,19 +118,20 @@ MLUtils.batch(vec_gs)
 
 # The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:
 
-function create_model(nin, nh, nout)
-    GNNChain(GCNConv(nin => nh, relu),
-             GCNConv(nh => nh, relu),
-             GCNConv(nh => nh),
+# Then use it in the model
+function create_model_graphconv(nin, nh, nout)
+    GNNChain(GraphConv(nin => nh, relu),
+             GraphConv(nh => nh, relu),
+             GraphConv(nh => nh),
              GlobalPool(mean),
              Dropout(0.5),
              Dense(nh, nout))
-end;
+end
 
 nin = 7
 nh = 64
 nout = 2
-model = create_model(nin, nh, nout)
+model = create_model_graphconv(nin, nh, nout)
 
 ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);
 
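As a quick smoke test of the renamed model, a hedged sketch of a single forward pass on a synthetic graph (in the tutorial itself the MUTAG loaders provide the real inputs):

g = rand_graph(rng, 12, 40)   # synthetic stand-in with 12 nodes
x = rand(Float32, nin, 12)    # nin = 7 features per node
ŷ, _ = model(g, x, ps, st)    # 2×1 logits after the GlobalPool readout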

@@ -191,11 +203,4 @@ model, ps, st = train_model!(model, ps, st);
 
 # This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.
 
-# As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`.
-# This should bring you close to **82% test accuracy**.
-
-# ## Conclusion
-
-# In this chapter, you have learned how to apply GNNs to the task of graph classification.
-# You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.
-
+# As an exercise, you are invited to complete the following code to the extent that it makes use of `
Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
+using Lux
+using GNNLux
+using MLDatasets
+using MLUtils
+using LinearAlgebra, Random, Statistics
+using Zygote, Optimisers, OneHotArrays
+using MLDatasets: TemporalBrains
+using GNNlib
+using GNNGraphs # provides mlgraph2gnngraph and TemporalSnapshotsGNNGraph, used below
+
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
+rng = Random.seed!(42); # for reproducibility
+
+brain_dataset = MLDatasets.TemporalBrains()
+
+function data_loader(brain_dataset)
+    graphs = brain_dataset.graphs
+    dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs))
+    for i in 1:length(graphs)
+        graph = graphs[i]
+        dataset[i] = TemporalSnapshotsGNNGraph(GNNGraphs.mlgraph2gnngraph.(graph.snapshots))
+        ## Add graph and node features
+        for t in 1:27
+            s = dataset[i].snapshots[t]
+            s.ndata.x = Float32.([I(102); s.ndata.x'])
+        end
+        dataset[i].tgdata.g = Float32.(onehotbatch([graph.graph_data.g], ["F", "M"]))
+    end
+
+    ## Split the dataset into an 80% training set and a 20% test set
+    train_graphs = dataset[1:200]
+    test_graphs = dataset[201:250]
+
+    ## Create tuples of (graph, label) for compatibility with the training loop
+    train_loader = [(g, g.tgdata.g) for g in train_graphs]
+    test_loader = [(g, g.tgdata.g) for g in test_graphs]
+
+    return train_loader, test_loader
+end
+
+struct GlobalPool{F} <: GNNLayer
+    aggr::F
+end
+
+# Implementation for a regular GNNGraph (similar to graph_classification.jl)
+(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
+
+# Implementation for TemporalSnapshotsGNNGraph: pools each snapshot, then averages over time
+function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector, ps, st)
+    h = [GNNlib.global_pool(l, g.snapshots[i], x[i]) for i in 1:g.num_snapshots]
+    return mean(h), st
+end
+
+# Convenience method for directly creating graph-level embeddings; ps and st are
+# passed explicitly, since they are not otherwise in scope here
+(l::GlobalPool)(g::GNNGraph, ps, st) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)[1])
+
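A minimal sketch of the temporal pooling method above on synthetic snapshots (graph sizes and feature dimensions are illustrative, not from the commit):

snaps = [rand_graph(rng, 5, 10) for _ in 1:3]
tg = TemporalSnapshotsGNNGraph(snaps)
xs = [rand(Float32, 8, 5) for _ in 1:3]   # 8 features per node, per snapshot
pool = GlobalPool(mean)
h, _ = pool(tg, xs, (;), (;))             # 8×1 embedding, averaged over the 3 snapshots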
+struct GenderPredictionModel <: AbstractLuxLayer
+    gin::GINConv
+    mlp::Chain
+    globalpool::GlobalPool
+    dense::Dense
+end
+
+# Implementation for GINConv with TemporalSnapshotsGNNGraph (non-mutating version)
+function (l::GINConv)(g::TemporalSnapshotsGNNGraph, x::AbstractVector, ps, st)
+    ## Use map instead of preallocation and mutation, which Zygote cannot differentiate
+    results = map(1:g.num_snapshots) do i
+        l(g.snapshots[i], x[i], ps, st)
+    end
+
+    ## Extract the per-snapshot outputs and the final state
+    h = [r[1] for r in results]
+    st_final = results[end][2]
+
+    return h, st_final
+end
+
+# Constructor for GenderPredictionModel using Lux components
+function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu)
+    mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ))
+    gin = GINConv(mlp, 0.5f0)
+    globalpool = GlobalPool(mean)
+    dense = Dense(nhidden => 2)
+    return GenderPredictionModel(gin, mlp, globalpool, dense)
+end
+
+# Type-constrained forward pass
+function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph,
+                                    x::AbstractVector,
+                                    ps::NamedTuple,
+                                    st::NamedTuple)
+    ## With the annotations above, Julia throws a MethodError if the input types don't match
+    h, st_gin = m.gin(g, x, ps.gin, st.gin)
+    h, st_globalpool = m.globalpool(g, h, ps.globalpool, st.globalpool)
+    output, st_dense = m.dense(h, ps.dense, st.dense)
+
+    st_new = (gin = st_gin, globalpool = st_globalpool, dense = st_dense)
+    return output, st_new
+end
+
+# Type-constrained custom loss that handles Lux's optional `layers` state wrapper
+function custom_loss(model::GenderPredictionModel,
+                     ps::NamedTuple,
+                     st::NamedTuple,
+                     data::Tuple{TemporalSnapshotsGNNGraph, AbstractVector, AbstractMatrix})
+    g, x, y = data
+    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
+
+    ## Unwrap the `layers` field, if present, to get the actual state structure
+    actual_st = haskey(st, :layers) ? st.layers : st
+
+    ## Ensure the state is in train mode
+    actual_st = Lux.trainmode(actual_st)
+
+    ## Forward pass
+    ŷ, new_st = model(g, x, ps, actual_st)
+
+    ## Wrap the new state back in the `layers` structure if needed
+    final_st = haskey(st, :layers) ? (layers = new_st,) : new_st
+
+    return logitcrossentropy(ŷ, y), final_st, 0
+end
+
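The `haskey(st, :layers)` checks guard against Lux wrapping the model state in a `layers` field, as some Lux containers do; a purely illustrative sketch of the two layouts the loss accepts, with made-up empty states:

st_plain = (gin = (;), globalpool = (;), dense = (;))
st_nested = (layers = st_plain,)
haskey(st_nested, :layers)   # true, so custom_loss unwraps st_nested.layers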
+# Implement the Lux interface methods for parameter and state initialization
+function LuxCore.initialparameters(rng::AbstractRNG, m::GenderPredictionModel)
+    return (gin = LuxCore.initialparameters(rng, m.gin),
+            mlp = LuxCore.initialparameters(rng, m.mlp),
+            globalpool = LuxCore.initialparameters(rng, m.globalpool),
+            dense = LuxCore.initialparameters(rng, m.dense))
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, m::GenderPredictionModel)
+    return (gin = LuxCore.initialstates(rng, m.gin),
+            mlp = LuxCore.initialstates(rng, m.mlp),
+            globalpool = LuxCore.initialstates(rng, m.globalpool),
+            dense = LuxCore.initialstates(rng, m.dense))
+end
+
+# Initialize the model, parameters, and states
+model = GenderPredictionModel()
+ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model);
+
+# Simple binary cross-entropy on sigmoid outputs, used during evaluation
+lossfunction(ŷ, y) = mean(-y .* log.(sigmoid.(ŷ)) .- (1 .- y) .* log.(1 .- sigmoid.(ŷ)));
+
+function eval_loss_accuracy(model, ps, st, data_loader)
+    losses = []
+    accs = []
+
+    for (g, y) in data_loader
+        ## Extract the features from each snapshot
+        x = [s.ndata.x for s in g.snapshots]
+
+        ## Forward pass with the Lux model
+        ŷ, _ = model(g, x, ps, st)
+
+        ## Calculate the loss
+        push!(losses, lossfunction(ŷ, y))
+
+        ## Calculate the accuracy by comparing predicted and true class indices
+        pred_indices = [argmax(ŷ[:, i]) for i in 1:size(ŷ, 2)]
+        true_indices = [argmax(y[:, i]) for i in 1:size(y, 2)]
+        accuracy = round(100 * mean(pred_indices .== true_indices), digits = 2)
+        push!(accs, accuracy)
+    end
+
+    return (loss = mean(losses), acc = mean(accs))
+end
+
+# Build the loaders; training itself is handled by the `train` function below,
+# which owns the `TrainState` and the reporting closure
+train_loader, test_loader = data_loader(brain_dataset)
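With the loaders built and the model initialized, a hedged smoke test of the full pipeline before training (uses only objects defined above):

g, y = first(train_loader)
x = [s.ndata.x for s in g.snapshots]   # 27 feature matrices of size 103×102
ŷ, _ = model(g, x, ps, st)             # 2×1 gender logits
lossfunction(ŷ, y)                     # scalar loss for this single example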
+
+function train(model, train_loader, test_loader)
+    train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2))
+
+    function report(epoch)
+        current_ps = train_state.parameters
+        current_st = train_state.states
+        train = eval_loss_accuracy(model, current_ps, current_st, train_loader)
+        test_st = Lux.testmode(current_st)
+        test = eval_loss_accuracy(model, current_ps, test_st, test_loader)
+        @info (; epoch, train, test)
+    end
+
+    for epoch in 1:5
+        for (g, y) in train_loader
+            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.ndata.x, y), train_state)
+        end
+        report(epoch)
+    end
+end
+
+train(model, train_loader, test_loader)

GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl

Lines changed: 12 additions & 2 deletions
@@ -9,12 +9,17 @@
 #
 # We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others.
 
+## Comments from Miguel for Claudio:
+# 1. Add a way to check that datasets have downloaded correctly; otherwise problems may arise. This happened to me when downloading the TemporalBrains dataset.
+
 using Flux
 using GraphNeuralNetworks
 using Statistics, Random
 using LinearAlgebra
 using MLDatasets: TemporalBrains
-using CUDA # comment out if you don't have a CUDA GPU
+using DataDeps
+# using CUDA # uncomment if you have a CUDA GPU
 
 ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
 Random.seed!(17); # for reproducibility
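Regarding the review note above, a minimal sketch of such a download check (the expected counts come from the tutorial text: 1000 graphs with 27 snapshots each):

brain_dataset = TemporalBrains()
@assert length(brain_dataset.graphs) == 1000 "unexpected graph count; the download may be incomplete"
@assert all(length(g.snapshots) == 27 for g in brain_dataset.graphs) "some graphs are missing snapshots"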
@@ -29,7 +34,7 @@ Random.seed!(17); # for reproducibility
 # Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+).
 # The network's edge weights are binarized, and the threshold is set to 0.6 by default.
 
-brain_dataset = TemporalBrains()
+brain_dataset = MLDatasets.TemporalBrains()
 
 # After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format.
 # So we create a function called `data_loader` that does this conversion and splits the dataset into a training set and a test set. Due to computational costs, we use only 250 of the original 1000 graphs: 200 for training and 50 for testing.
@@ -83,6 +88,11 @@ end
 
 Flux.@layer GenderPredictionModel
 
+function (l::GINConv)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
+    h = [l(g[i], x[i]) for i in 1:(g.num_snapshots)]
+    return h
+end
+
 function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu)
     mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ))
     gin = GINConv(mlp, 0.5)
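A hedged sketch of what the new temporal `GINConv` method computes, on synthetic snapshots (the wrapped `Dense` network and all sizes are illustrative):

snaps = [rand_graph(4, 8) for _ in 1:3]
tg = TemporalSnapshotsGNNGraph(snaps)
xs = [rand(Float32, 103, 4) for _ in 1:3]
gin = GINConv(Dense(103 => 16), 0.5)
h = gin(tg, xs)   # a vector of 3 matrices, one 16×4 embedding per snapshot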
