Skip to content

Commit 8fdf813

Browse files
wip
1 parent 7387bab commit 8fdf813

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1212
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1415
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1516
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

examples/cora.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ end
1818
@functor GNN
1919

2020
function GNN(; nin, nhidden, nout)
21-
GNN(GCNConv(nin => nhidden, relu),
22-
GCNConv(nhidden => nhidden, relu),
21+
GNN(GraphConv(nin => nhidden, relu),
22+
GraphConv(nhidden => nhidden, relu),
2323
Dense(nhidden, nout))
2424
end
2525

@@ -35,13 +35,9 @@ function eval_loss_accuracy(X, y, ids, model, fg)
3535
= model(fg, X)
3636
l = logitcrossentropy(ŷ[:,ids], y[:,ids])
3737
acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
38-
return (loss = l |> round4, acc = acc*100 |> round4)
38+
return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
3939
end
4040

41-
## utility functions
42-
num_params(model) = sum(length, Flux.params(model))
43-
round4(x) = round(x, digits=4)
44-
4541
# arguments for the `train` function
4642
Base.@kwdef mutable struct Args
4743
η = 1f-3 # learning rate
@@ -68,7 +64,7 @@ function train(; kws...)
6864
end
6965

7066
data = Cora.dataset()
71-
fg = FeaturedGraph(data.adjacency_list)
67+
fg = FeaturedGraph(data.adjacency_list) |> device
7268
X = data.node_features |> device
7369
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
7470
train_ids = data.train_indices |> device
@@ -85,9 +81,8 @@ function train(; kws...)
8581

8682
function report(epoch)
8783
train = eval_loss_accuracy(X, y, train_ids, model, fg)
88-
val = eval_loss_accuracy(X, y, val_ids, model, fg)
8984
test = eval_loss_accuracy(X, y, test_ids, model, fg)
90-
println("Epoch: $epoch Train: $(train) Val: $(val) Test: $(test)")
85+
println("Epoch: $epoch Train: $(train) Test: $(test)")
9186
end
9287

9388
## TRAINING

0 commit comments

Comments
 (0)