We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 322bb24 commit 0e8a9caCopy full SHA for 0e8a9ca
examples/cora.jl
@@ -18,8 +18,8 @@ end
18
@functor GNN
19
20
function GNN(; nin, nhidden, nout)
21
- GNN(GraphConv(nin => nhidden, relu),
22
- GraphConv(nhidden => nhidden, relu),
+ GNN(GCNConv(nin => nhidden, relu),
+ GCNConv(nhidden => nhidden, relu),
23
Dense(nhidden, nout))
24
end
25
@@ -70,7 +70,8 @@ function train(; kws...)
70
train_ids = data.train_indices |> device
71
val_ids = data.val_indices |> device
72
test_ids = data.test_indices |> device
73
-
+ ytrain = y[:,train_ids]
74
+
75
model = GNN(nin=size(X,1),
76
nhidden=args.nhidden,
77
nout=data.num_classes) |> device
@@ -90,7 +91,7 @@ function train(; kws...)
90
91
for epoch in 1:args.epochs
92
gs = Flux.gradient(ps) do
93
ŷ = model(fg, X)
- logitcrossentropy(ŷ[:,train_ids], y[:,train_ids])
94
+ logitcrossentropy(ŷ[:,train_ids], ytrain)
95
96
97
Flux.Optimise.update!(opt, ps, gs)
0 commit comments