Commit 0196897

Add GNNLux training example in docs (#521)

* Lux training
* ok works
* Second version
* Add Lux training example

1 parent: 7a247be

2 files changed: +83 -2 lines changed

GNNLux/docs/make.jl (2 additions & 1 deletion)

@@ -24,7 +24,8 @@ makedocs(;
     "API Reference" => [
         "Basic" => "api/basic.md",
         "Convolutional layers" => "api/conv.md",
-        "Temporal Convolutional layers" => "api/temporalconv.md",]]
+        "Temporal Convolutional layers" => "api/temporalconv.md",],
+    ]
 )
GNNLux/docs/src/index.md (81 additions & 1 deletion)

Removed: "The full documentation will be available soon."

GNNLux.jl is a work-in-progress package that implements stateless graph convolutional layers, fully compatible with the [Lux.jl](https://lux.csail.mit.edu/stable/) machine learning framework. It is built on top of the GNNGraphs.jl, GNNlib.jl, and Lux.jl packages.
## Package overview

Let's give a brief overview of the package by solving a graph regression problem with synthetic data.

### Data preparation

We generate a dataset of multiple random graphs with associated data features, then split it into training and testing sets.
```julia
using GNNLux, Lux, Statistics, MLUtils, Random
using Zygote, Optimisers

rng = Random.default_rng()

all_graphs = GNNGraph[]

for _ in 1:1000
    g = rand_graph(rng, 10, 40,
                   ndata = (; x = randn(rng, Float32, 16, 10)),  # input node features
                   gdata = (; y = randn(rng, Float32)))          # regression target
    push!(all_graphs, g)
end

# 80/20 train/test split
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at = 0.8)
```
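Each element of the dataset is a self-contained `GNNGraph`. As a quick sanity check (an addition, not part of the committed example), the fields of one sample reflect the construction above:

```julia
g = first(train_graphs)
g.num_nodes   # 10: nodes per graph
size(g.x)     # (16, 10): feature dimension × number of nodes
g.y           # scalar Float32 regression target
```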
### Model building

We concisely define our model as a [`GNNLux.GNNChain`](@ref) containing two graph convolutional layers and initialize the model's parameters and state.

```julia
model = GNNChain(GCNConv(16 => 64),
                 x -> relu.(x),
                 Dropout(0.6),
                 GCNConv(64 => 64, relu),
                 x -> mean(x, dims=2),   # pool node features into one graph embedding
                 Dense(64, 1))

ps, st = LuxCore.setup(rng, model)
```
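Before training, it can help to run a single forward pass. This sketch (an addition, not part of the committed example) relies on the Lux convention, also used below, that a model call returns the output together with the updated state; with the layer sizes above, each graph maps to a single scalar:

```julia
g = first(train_graphs)
y_pred, _ = model(g, g.x, ps, st)  # Lux models return (output, new_state)
size(y_pred)                       # (1, 1): one prediction for the whole graph
```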
### Training

Finally, we use a standard Lux training pipeline to fit our dataset.
```julia
function custom_loss(model, ps, st, data)
    g, x, y = data
    y_pred, st = model(g, x, ps, st)
    # The chain returns the state of its inner layers; rewrap it to match the
    # structure produced by LuxCore.setup. The trailing value is the (unused)
    # stats expected by the Lux training API.
    return MSELoss()(y_pred, y), (layers = st,), 0
end

function train_model!(model, ps, st, train_graphs, test_graphs)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

    for iter in 1:100
        train_loss = 0.0   # reset the running loss at the start of each epoch
        for g in train_graphs
            _, loss, _, train_state = Lux.Training.single_train_step!(
                AutoZygote(), custom_loss, (g, g.x, g.y), train_state)
            train_loss += loss
        end
        train_loss = train_loss / length(train_graphs)

        if iter % 10 == 0
            st_ = Lux.testmode(train_state.states)   # disable dropout for evaluation
            test_loss = 0.0
            for g in test_graphs
                ŷ, st_ = model(g, g.x, train_state.parameters, st_)
                st_ = (layers = st_,)
                test_loss += MSELoss()(ŷ, g.y)
            end
            test_loss = test_loss / length(test_graphs)

            @info (; iter, train_loss, test_loss)
        end
    end

    return model, ps, st
end

train_model!(model, ps, st, train_graphs, test_graphs)
```
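Inference follows the same pattern as the evaluation loop above. A minimal sketch, assuming `train_model!` were modified to also return its final `train_state` (as written, it returns the initial parameters and state):

```julia
# Hypothetical: assumes `train_state` is the final Lux.Training.TrainState
# returned by a modified train_model!.
st_eval = Lux.testmode(train_state.states)   # disable dropout
g = first(test_graphs)
ŷ, _ = model(g, g.x, train_state.parameters, st_eval)
```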
