Skip to content

Commit f96f58e

Browse files
wip on rand_split_edge
1 parent d5c7a80 commit f96f58e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

examples/link_prediction_pubmed.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct DotPredictor end
2828

2929
function (::DotPredictor)(g, x)
3030
z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
31-
# z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with buit-in methods
31+
# z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with built-in method
3232
return vec(z)
3333
end
3434

@@ -50,18 +50,23 @@ function train(; kws...)
5050
data = Cora.dataset()
5151
# data = PubMed.dataset()
5252
g = GNNGraph(data.adjacency_list)
53+
54+
# Print some info
5355
@info g
5456
@show is_bidirected(g)
5557
@show has_self_loops(g)
5658
@show has_multi_edges(g)
5759
@show mean(degree(g))
5860
isbidir = is_bidirected(g)
5961

62+
# Move to device
6063
g = g |> device
6164
X = data.node_features |> device
6265

63-
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
64-
train_pos_g, test_pos_g = rand_edge_split(g, 0.9)
66+
#### TRAIN/TEST splits
67+
# With bidirected graph, we make sure that an edge and its reverse
68+
# are in the same split
69+
train_pos_g, test_pos_g = rand_edge_split(g, 0.9, bidirected=isbidir)
6570
test_neg_g = negative_sample(g, num_neg_edges=test_pos_g.num_edges, bidirected=isbidir)
6671

6772
### DEFINE MODEL #########

0 commit comments

Comments
 (0)