Skip to content

Commit db92ade

Browse files
committed
Second version
1 parent ac405db commit db92ade

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

GNNLux/docs/src/index.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)
3030
rng = Random.default_rng()
3131

3232
model = GNNChain(GCNConv(16 => 64),
33-
x -> relu.(x),
33+
x -> relu.(x),
34+
Dropout(0.6),
3435
GCNConv(64 => 64, relu),
3536
x -> mean(x, dims=2),
3637
Dense(64, 1))
@@ -46,11 +47,13 @@ end
4647

4748
function train_model!(model, ps, st, train_graphs, test_graphs)
4849
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
49-
loss=0
50+
train_loss=0
5051
for iter in 1:100
5152
for g in train_graphs
5253
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, g.x, g.y), train_state)
54+
train_loss += loss
5355
end
56+
train_loss = train_loss/length(train_graphs)
5457
if iter % 10 == 0 || iter == 100
5558
st_ = Lux.testmode(train_state.states)
5659
test_loss =0
@@ -60,7 +63,7 @@ function train_model!(model, ps, st, train_graphs, test_graphs)
6063
test_loss += MSELoss()(g.y,ŷ)
6164
end
6265
test_loss = test_loss/length(test_graphs)
63-
@info (; iter, loss, test_loss)
66+
@info (; iter, train_loss, test_loss)
6467
end
6568
end
6669

0 commit comments

Comments
 (0)