
Commit 5a4b1e2

Lux training
1 parent e1910ca commit 5a4b1e2

2 files changed (+80, -2)

GNNLux/docs/make.jl

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ makedocs(;
     "API Reference" => [
         "Basic" => "api/basic.md",
         "Convolutional layers" => "api/conv.md",
-        "Temporal Convolutional layers" => "api/temporalconv.md",]]
+        "Temporal Convolutional layers" => "api/temporalconv.md",],
+    ]
 )
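For context, the fix closes the "API Reference" block inside the pages vector instead of with doubled brackets. The resulting layout is sketched below (plain Julia; entries taken from the hunk):

```julia
# The nesting that the fix restores: the "API Reference" entry now closes
# with "]," inside the pages vector, and the vector itself closes with "]".
pages = [
    "API Reference" => [
        "Basic" => "api/basic.md",
        "Convolutional layers" => "api/conv.md",
        "Temporal Convolutional layers" => "api/temporalconv.md",
    ],
]
```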

GNNLux/docs/src/index.md

Lines changed: 78 additions & 1 deletion
@@ -2,4 +2,81 @@
GNNLux.jl is a work-in-progress package that implements stateless graph convolutional layers, fully compatible with the [Lux.jl](https://lux.csail.mit.edu/stable/) machine learning framework. It is built on top of the GNNGraphs.jl, GNNlib.jl, and Lux.jl packages.

-The full documentation will be available soon.

## Package overview

Let's give a brief overview of the package by solving a graph regression problem with synthetic data.

### Data preparation

We create a dataset consisting of multiple random graphs and associated data features.
```julia
using GNNLux, Lux, Statistics, MLUtils, Random
using Zygote, Optimisers

all_graphs = GNNGraph[]

for _ in 1:1000
    g = rand_graph(10, 40,
                   ndata = (; x = randn(Float32, 16, 10)),  # Input node features
                   gdata = (; y = randn(Float32)))          # Regression target
    push!(all_graphs, g)
end
```
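Each element of `all_graphs` is a `GNNGraph` carrying the node features and the graph-level target; a quick look at one sample (an illustrative check):

```julia
g = all_graphs[1]
g.num_nodes   # 10
g.num_edges   # 40
size(g.x)     # (16, 10): one 16-dimensional feature vector per node
g.y           # scalar regression target
```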
### Model building

We concisely define our model as a [`GNNLux.GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, the model's parameters and state will live on the GPU.
```julia
using CUDA  # needed for CUDA.functional()

device = CUDA.functional() ? Lux.gpu_device() : Lux.cpu_device()
rng = Random.default_rng()

model = GNNChain(GCNConv(16 => 64),
                 x -> relu.(x),
                 GCNConv(64 => 64, relu),
                 GlobalMeanPool(),  # Aggregate node-wise features into graph-wise features
                 Dense(64, 1))

ps, st = LuxCore.setup(rng, model) |> device  # Move parameters and state to the device
```
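Since GNNLux layers are stateless, a forward pass takes the graph and the node features along with the explicit parameters and state. A minimal sketch using the objects defined above:

```julia
# Forward pass on a single graph: the GNNLux call convention is (graph, features, ps, st).
g = all_graphs[1] |> device
ŷ, _ = model(g, g.x, ps, st)  # ŷ should have size (1, 1): one prediction for this graph
```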
### Training

We split the dataset, create data loaders that collate each mini-batch of graphs into a single `GNNGraph`, and train with Lux's `Training` API: a small loss function computes the mean squared error between graph-level predictions and targets, and `single_train_step!` performs one gradient update.

```julia
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at = 0.8)

train_loader = MLUtils.DataLoader(train_graphs,
                                  batchsize = 32, shuffle = true, collate = true)
test_loader = MLUtils.DataLoader(test_graphs,
                                 batchsize = 32, shuffle = false, collate = true)

# Loss for Lux's Training API: must return (loss, updated state, stats).
function custom_loss(model, ps, st, (g, y))
    ŷ, st = model(g, g.x, ps, st)
    return MSELoss()(vec(ŷ), y), st, (;)  # vec matches the 1 × nbatch output to the target vector
end

function train_model!(model, ps, st, train_loader)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

    for iter in 1:1000
        for g in train_loader
            g = g |> device  # Move the batched graph to the device
            _, loss, _, train_state = Lux.Training.single_train_step!(
                AutoZygote(), custom_loss, (g, g.y), train_state)
            if iter % 100 == 1 || iter == 1000
                @info "Iteration: $iter \t Loss: $loss"
            end
        end
    end

    # Return the trained parameters and state held by the train state
    return model, train_state.parameters, train_state.states
end

model, ps, st = train_model!(model, ps, st, train_loader)
```
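Once training has finished, we can estimate generalization on the held-out graphs. A minimal evaluation sketch, assuming the trained `ps` and `st` returned by `train_model!` (the `eval_loss` helper is illustrative, not part of the package):

```julia
# Mean squared error over the test set.
function eval_loss(model, ps, st, loader)
    losses = Float32[]
    for g in loader
        g = g |> device
        ŷ, _ = model(g, g.x, ps, st)
        push!(losses, MSELoss()(vec(ŷ), g.y))
    end
    return mean(losses)
end

@info "Test loss: $(eval_loss(model, ps, st, test_loader))"
```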
