@@ -49,23 +49,19 @@ function train(; kws...)
     # ## LOAD DATA
     data = Cora.dataset()
     # data = PubMed.dataset()
-    g = GNNGraph(data.adjacency_list) |> device
+    g = GNNGraph(data.adjacency_list)
+    @info g
     @show is_bidirected(g)
+    @show has_self_loops(g)
+    @show has_multi_edges(g)
+    @show mean(degree(g))
+
+    g = g |> device
     X = data.node_features |> device
 
     # ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
-    s, t = edge_index(g)
-    eids = randperm(g.num_edges)
-    test_size = round(Int, g.num_edges * 0.1)
-
-    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
-    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)
-
-    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
-    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)
-
-    test_neg_g = negative_sample(g, num_neg_edges=test_size)
-
+    train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
+    test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges)
 
     # ## DEFINE MODEL #########
     nin, nhidden = size(X, 1), args.nhidden
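As an aside, the replacement of the manual randperm-based split can be sketched on a toy graph roughly as follows; the `rand_graph` call and the graph size are made up purely for illustration and are not part of this change:

    using GraphNeuralNetworks

    # Toy graph: 10 nodes, 30 (bidirected) edges, just for illustration.
    g = rand_graph(10, 30)

    # Keep ~90% of the edges as training positives, hold out the rest as test positives.
    train_pos_g, test_pos_g = rand_edge_split(g, 0.9)

    # Sample as many non-existing edges as there are held-out positives.
    test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges)

    @show train_pos_g.num_edges test_pos_g.num_edges test_neg_g.num_edges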
@@ -82,7 +78,7 @@ function train(; kws...)
 
     # ## LOSS FUNCTION ############
 
-    function loss(pos_g, neg_g = nothing)
+    function loss(pos_g, neg_g = nothing; with_accuracy = false)
         h = model(X)
         if neg_g === nothing
             # We sample a negative graph at each training step
@@ -92,14 +88,20 @@ function train(; kws...)
         neg_score = pred(neg_g, h)
         scores = [pos_score; neg_score]
         labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
-        return logitbinarycrossentropy(scores, labels)
+        l = logitbinarycrossentropy(scores, labels)
+        if with_accuracy
+            acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0)
+            return l, acc
+        else
+            return l
+        end
     end
 
     # ## LOGGING FUNCTION
     function report(epoch)
-        train_loss = loss(train_pos_g)
-        test_loss = loss(test_pos_g, test_neg_g)
-        println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
+        train_loss, train_acc = loss(train_pos_g, with_accuracy = true)
+        test_loss, test_acc = loss(test_pos_g, test_neg_g, with_accuracy = true)
+        println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
     end
 
     # ## TRAINING
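For intuition on the accuracy term that `loss` now returns: a logit of zero corresponds to a predicted edge probability of 0.5, so a positive edge is counted as correct when its score is >= 0 and a negative edge when its score is < 0. A self-contained sketch of the same formula with made-up scores:

    using Statistics: mean

    # Arbitrary raw scores (logits), for illustration only.
    pos_score = [2.3, 0.4, -0.1, 1.7]    # scores of true edges
    neg_score = [-1.2, 0.3, -2.0, -0.5]  # scores of sampled non-edges

    # Same formula as in the loss function above.
    acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0)
    # here: 0.5 * 3/4 + 0.5 * 3/4 = 0.75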