Skip to content

Commit b62b53b

Browse files
committed
Fix loss function name
1 parent cac6506 commit b62b53b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

GNNLux/docs/src_tutorials/node_classification.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ ps, st = Lux.setup(rng, MLP);
9696
# This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training).
9797

9898

99-
function custom_loss(model, ps, st, x)
99+
function loss(model, ps, st, x)
100100
logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
101101
ŷ, st = model(x, ps, st)
102102
return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0
@@ -105,10 +105,10 @@ end
105105
function train_model!(MLP, ps, st, x, epochs)
106106
train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3))
107107
for iter in 1:epochs
108-
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state)
108+
_, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state)
109109

110110
if iter % 100 == 0
111-
println("Epoch: $(iter) Loss: $(loss)")
111+
println("Epoch: $(iter) Loss: $(loss_value)")
112112
end
113113
end
114114
end
@@ -208,7 +208,7 @@ visualize_tsne(h_untrained, g.ndata.targets)
208208

209209

210210

211-
function custom_loss(gcn, ps, st, tuple)
211+
function loss(gcn, ps, st, tuple)
212212
g, x, y = tuple
213213
logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
214214
ŷ, st = gcn(g, x, ps, st)
@@ -218,10 +218,10 @@ end
218218
function train_model!(gcn, ps, st, g, x, y)
219219
train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
220220
for iter in 1:2000
221-
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state)
221+
_, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss,(g, x, y), train_state)
222222

223223
if iter % 100 == 0
224-
println("Epoch: $(iter) Loss: $(loss)")
224+
println("Epoch: $(iter) Loss: $(loss_value)")
225225
end
226226
end
227227

0 commit comments

Comments
 (0)