Skip to content

Commit 0e8a9ca

Browse files
fix cuda example
1 parent 322bb24 commit 0e8a9ca

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

examples/cora.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ end
1818
@functor GNN
1919

2020
function GNN(; nin, nhidden, nout)
21-
GNN(GraphConv(nin => nhidden, relu),
22-
GraphConv(nhidden => nhidden, relu),
21+
GNN(GCNConv(nin => nhidden, relu),
22+
GCNConv(nhidden => nhidden, relu),
2323
Dense(nhidden, nout))
2424
end
2525

@@ -70,7 +70,8 @@ function train(; kws...)
7070
train_ids = data.train_indices |> device
7171
val_ids = data.val_indices |> device
7272
test_ids = data.test_indices |> device
73-
73+
ytrain = y[:,train_ids]
74+
7475
model = GNN(nin=size(X,1),
7576
nhidden=args.nhidden,
7677
nout=data.num_classes) |> device
@@ -90,7 +91,7 @@ function train(; kws...)
9091
for epoch in 1:args.epochs
9192
gs = Flux.gradient(ps) do
9293
= model(fg, X)
93-
logitcrossentropy(ŷ[:,train_ids], y[:,train_ids])
94+
logitcrossentropy(ŷ[:,train_ids], ytrain)
9495
end
9596

9697
Flux.Optimise.update!(opt, ps, gs)

0 commit comments

Comments
 (0)