Skip to content

Commit 7f14dc9

Browse files
stuff
1 parent 683c8b7 commit 7f14dc9

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

examples/link_prediction_pubmed.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,19 @@ function train(; kws...)
4949
### LOAD DATA
5050
data = Cora.dataset()
5151
# data = PubMed.dataset()
52-
g = GNNGraph(data.adjacency_list) |> device
52+
g = GNNGraph(data.adjacency_list)
53+
@info g
5354
@show is_bidirected(g)
55+
@show has_self_loops(g)
56+
@show has_multi_edges(g)
57+
@show mean(degree(g))
58+
59+
g = g |> device
5460
X = data.node_features |> device
5561

5662
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
57-
s, t = edge_index(g)
58-
eids = randperm(g.num_edges)
59-
test_size = round(Int, g.num_edges * 0.1)
60-
61-
test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
62-
test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)
63-
64-
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
65-
train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)
66-
67-
test_neg_g = negative_sample(g, num_neg_edges=test_size)
68-
63+
train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
64+
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges)
6965

7066
### DEFINE MODEL #########
7167
nin, nhidden = size(X,1), args.nhidden
@@ -82,7 +78,7 @@ function train(; kws...)
8278

8379
### LOSS FUNCTION ############
8480

85-
function loss(pos_g, neg_g = nothing)
81+
function loss(pos_g, neg_g = nothing; with_accuracy=false)
8682
h = model(X)
8783
if neg_g === nothing
8884
# We sample a negative graph at each training step
@@ -92,14 +88,20 @@ function train(; kws...)
9288
neg_score = pred(neg_g, h)
9389
scores = [pos_score; neg_score]
9490
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
95-
return logitbinarycrossentropy(scores, labels)
91+
l = logitbinarycrossentropy(scores, labels)
92+
if with_accuracy
93+
acc = 0.5 * mean(pos_score .>= 0) + 0.5 * mean(neg_score .< 0)
94+
return l, acc
95+
else
96+
return l
97+
end
9698
end
9799

98100
### LOGGING FUNCTION
99101
function report(epoch)
100-
train_loss = loss(train_pos_g)
101-
test_loss = loss(test_pos_g, test_neg_g)
102-
println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
102+
train_loss, train_acc = loss(train_pos_g, with_accuracy=true)
103+
test_loss, test_acc = loss(test_pos_g, test_neg_g, with_accuracy=true)
104+
println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
103105
end
104106

105107
### TRAINING

0 commit comments

Comments
 (0)