
Commit e72ad31

Add Lux training example

1 parent db92ade commit e72ad31

File tree: 1 file changed (+22 −10 lines)


GNNLux/docs/src/index.md

Lines changed: 22 additions & 10 deletions
````diff
@@ -8,27 +8,31 @@ Let's give a brief overview of the package by solving a graph regression problem
 
 ### Data preparation
 
-We create a dataset consisting in multiple random graphs and associated data features.
+We generate a dataset of multiple random graphs with associated data features, then split it into training and testing sets.
 
 ```julia
 using GNNLux, Lux, Statistics, MLUtils, Random
 using Zygote, Optimisers
 
+rng = Random.default_rng()
+
 all_graphs = GNNGraph[]
 
 for _ in 1:1000
-    g = rand_graph(10, 40,
-                   ndata=(; x = randn(Float32, 16, 10)),  # Input node features
-                   gdata=(; y = randn(Float32)))          # Regression target
+    g = rand_graph(rng, 10, 40,
+                   ndata=(; x = randn(rng, Float32, 16, 10)),  # Input node features
+                   gdata=(; y = randn(rng, Float32)))          # Regression target
     push!(all_graphs, g)
 end
 
 train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)
+```
 
-# g = rand_graph(10, 40, ndata=(; x = randn(Float32, 16,10)), gdata=(; y = randn(Float32)))
+### Model building
 
-rng = Random.default_rng()
+We concisely define our model as a [`GNNLux.GNNChain`](@ref) containing two graph convolutional layers and initialize the model's parameters and state.
 
+```julia
 model = GNNChain(GCNConv(16 => 64),
                  x -> relu.(x),
                  Dropout(0.6),
@@ -37,14 +41,18 @@ model = GNNChain(GCNConv(16 => 64),
                  Dense(64, 1))
 
 ps, st = LuxCore.setup(rng, model)
+```
+### Training
+
+Finally, we use a standard Lux training pipeline to fit our dataset.
 
-function custom_loss(model, ps,st,tuple)
+```julia
+function custom_loss(model, ps, st, tuple)
     g, x, y = tuple
     y_pred, st = model(g, x, ps, st)
     return MSELoss()(y_pred, y), (layers = st,), 0
 end
 
-
 function train_model!(model, ps, st, train_graphs, test_graphs)
     train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
     train_loss = 0
@@ -53,8 +61,10 @@ function train_model!(model, ps, st, train_graphs, test_graphs)
            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.x, g.y), train_state)
            train_loss += loss
        end
+
       train_loss = train_loss / length(train_graphs)
-      if iter % 10 == 0 || iter == 100
+
+      if iter % 10 == 0
          st_ = Lux.testmode(train_state.states)
          test_loss = 0
          for g in test_graphs
@@ -63,11 +73,13 @@ function train_model!(model, ps, st, train_graphs, test_graphs)
             test_loss += MSELoss()(g.y, ŷ)
          end
          test_loss = test_loss / length(test_graphs)
+
          @info (; iter, train_loss, test_loss)
       end
    end
 
    return model, ps, st
end
 
-train_model!(model, ps, st, train_graphs, test_graphs)
+train_model!(model, ps, st, train_graphs, test_graphs)
+```
````
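The updated example ends right after calling `train_model!`. As a hedged follow-up sketch (not part of the commit), a reader could evaluate the trained model on a fresh graph; here `model`, `ps`, `st`, and `rng` are assumed to come from the example above, and switching to `Lux.testmode` mirrors the test loop in the diff, which disables `Dropout` for inference:

```julia
using GNNLux, Lux, Random

# Hypothetical inference step, assuming `model`, `ps`, `st`, and `rng`
# were created as in the tutorial above.
st_eval = Lux.testmode(st)   # put Dropout into inference mode

# A fresh random graph with the same node-feature shape used in training.
g_new = rand_graph(rng, 10, 40, ndata = (; x = randn(rng, Float32, 16, 10)))

# The model returns the prediction and the (unchanged) evaluation state.
ŷ, _ = model(g_new, g_new.x, ps, st_eval)
```

Since the chain ends in `Dense(64, 1)`, `ŷ` is a single graph-level regression value per graph.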
