18
18
@functor GNN
19
19
20
20
function GNN (; nin, nhidden, nout)
21
- GNN (GCNConv (nin => nhidden, relu),
22
- GCNConv (nhidden => nhidden, relu),
21
+ GNN (GraphConv (nin => nhidden, relu),
22
+ GraphConv (nhidden => nhidden, relu),
23
23
Dense (nhidden, nout))
24
24
end
25
25
@@ -35,13 +35,9 @@ function eval_loss_accuracy(X, y, ids, model, fg)
35
35
ŷ = model (fg, X)
36
36
l = logitcrossentropy (ŷ[:,ids], y[:,ids])
37
37
acc = mean (onecold (ŷ[:,ids] |> cpu) .== onecold (y[:,ids] |> cpu))
38
- return (loss = l |> round4 , acc = acc* 100 |> round4 )
38
+ return (loss = round (l, digits = 4 ) , acc = round ( acc* 100 , digits = 2 ) )
39
39
end
40
40
41
- # # utility functions
42
- num_params (model) = sum (length, Flux. params (model))
43
- round4 (x) = round (x, digits= 4 )
44
-
45
41
# arguments for the `train` function
46
42
Base. @kwdef mutable struct Args
47
43
η = 1f-3 # learning rate
@@ -68,7 +64,7 @@ function train(; kws...)
68
64
end
69
65
70
66
data = Cora. dataset ()
71
- fg = FeaturedGraph (data. adjacency_list)
67
+ fg = FeaturedGraph (data. adjacency_list) |> device
72
68
X = data. node_features |> device
73
69
y = onehotbatch (data. node_labels, 1 : data. num_classes) |> device
74
70
train_ids = data. train_indices |> device
@@ -85,9 +81,8 @@ function train(; kws...)
85
81
86
82
function report (epoch)
87
83
train = eval_loss_accuracy (X, y, train_ids, model, fg)
88
- val = eval_loss_accuracy (X, y, val_ids, model, fg)
89
84
test = eval_loss_accuracy (X, y, test_ids, model, fg)
90
- println (" Epoch: $epoch Train: $(train) Val: $(val) Test: $(test) " )
85
+ println (" Epoch: $epoch Train: $(train) Test: $(test) " )
91
86
end
92
87
93
88
# # TRAINING
0 commit comments