Skip to content

Commit ac405db

Browse files
committed
ok works
1 parent 5a4b1e2 commit ac405db

File tree

1 file changed

+25
-37
lines changed

1 file changed

+25
-37
lines changed

GNNLux/docs/src/index.md

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ We create a dataset consisting in multiple random graphs and associated data fea
1212

1313
```julia
1414
using GNNLux, Lux, Statistics, MLUtils, Random
15-
using Zygote, Optimizers
15+
using Zygote, Optimisers
1616

1717
all_graphs = GNNGraph[]
1818

@@ -22,61 +22,49 @@ for _ in 1:1000
2222
gdata=(; y = randn(Float32))) # Regression target
2323
push!(all_graphs, g)
2424
end
25-
```
2625

27-
### Model building
26+
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)
2827

29-
We concisely define our model as a [`GNNLux.GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, our model will live on the gpu.
28+
# g = rand_graph(10, 40, ndata=(; x = randn(Float32, 16,10)), gdata=(; y = randn(Float32)))
3029

31-
```julia
32-
device = CUDA.functional() ? Lux.gpu_device() : Lux.cpu_device()
3330
rng = Random.default_rng()
3431

3532
model = GNNChain(GCNConv(16 => 64),
3633
x -> relu.(x),
3734
GCNConv(64 => 64, relu),
38-
GlobalMeanPool(), # Aggregate node-wise features into graph-wise features
35+
x -> mean(x, dims=2),
3936
Dense(64, 1))
4037

4138
ps, st = LuxCore.setup(rng, model)
42-
```
43-
44-
### Training
4539

46-
47-
```julia
48-
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)
49-
50-
train_loader = MLUtils.DataLoader(train_graphs,
51-
batchsize=32, shuffle=true, collate=true)
52-
test_loader = MLUtils.DataLoader(test_graphs,
53-
batchsize=32, shuffle=false, collate=true)
54-
55-
for epoch in 1:100
56-
for g in train_loader
57-
g = g |> device
58-
grad = gradient(model -> loss(model, g), model)
59-
Flux.update!(opt, model, grad[1])
60-
end
61-
62-
@info (; epoch, train_loss=loss(model, train_loader), test_loss=loss(model, test_loader))
40+
function custom_loss(model, ps,st,tuple)
41+
g,x,y = tuple
42+
y_pred,st = model(g, x, ps, st)
43+
return MSELoss()(y_pred, y), (layers = st,), 0
6344
end
6445

65-
function train_model!(model, ps, st, train_loader)
66-
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
6746

68-
for iter in 1:1000
69-
for g in train_loader
70-
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
71-
((g, g.x)...,g.y), train_state)
72-
if iter % 100 == 1 || iter == 1000
73-
@info "Iteration: %04d \t Loss: %10.9g\n" iter loss
47+
function train_model!(model, ps, st, train_graphs, test_graphs)
48+
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
49+
loss=0
50+
for iter in 1:100
51+
for g in train_graphs
52+
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, g.x, g.y), train_state)
53+
end
54+
if iter % 10 == 0 || iter == 100
55+
st_ = Lux.testmode(train_state.states)
56+
test_loss =0
57+
for g in test_graphs
58+
ŷ, st_ = model(g, g.x, train_state.parameters, st_)
59+
st_ = (layers = st_,)
60+
test_loss += MSELoss()(g.y,ŷ)
7461
end
62+
test_loss = test_loss/length(test_graphs)
63+
@info (; iter, loss, test_loss)
7564
end
7665
end
7766

7867
return model, ps, st
7968
end
8069

81-
train_model!(model, ps, st, train_loader)
82-
```
70+
train_model!(model, ps, st, train_graphs, test_graphs)

0 commit comments

Comments
 (0)