Skip to content

Commit 683c8b7

Browse files
Merge pull request #69 from CarloLucibello/cl/negative
better negative sampling + remove_multi_edges + is_bidirected + DotDecoder
2 parents e5e919c + e4b7557 commit 683c8b7

File tree

15 files changed

+266
-113
lines changed

15 files changed

+266
-113
lines changed

README.md

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,21 @@
55
![](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg)
66
[![codecov](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl)
77

8-
A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). Its features include:
8+
GraphNeuralNetworks.jl is a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
99

10-
* Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl).
11-
* Implementation of common graph convolutional layers.
12-
* Fast operations on batched graphs.
10+
Among its features:
11+
12+
* Implements common graph convolutional layers.
13+
* Supports computations on batched graphs.
1314
* Easy to define custom layers.
15+
* Integration with the JuliaGraphs ecosystem.
1416
* CUDA support.
17+
* Integration with [Graph.jl](https://github.com/JuliaGraphs/Graphs.jl).
18+
* [Examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks.
1519

1620
## Installation
1721

18-
GraphNeuralNetworks.jl is a registered julia package.
19-
You can easily install it through the package manager:
22+
GNN.jl is a registered julia package. You can easily install it through the package manager:
2023

2124
```julia
2225
pkg> add GraphNeuralNetworks
@@ -26,6 +29,9 @@ pkg> add GraphNeuralNetworks
2629

2730
Usage examples can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) folder. Also, make sure to read the [documentation](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/dev) for a comprehensive introduction to the library.
2831

29-
## Acknowledgements
32+
## Acknowledgments
33+
34+
GNN.jl is largely inspired by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/),[Deep Graph Library](https://docs.dgl.ai/),
35+
and [GeometricFlux.jl](https://fluxml.ai/GeometricFlux.jl/stable/).
36+
3037

31-
A big thanks goes to @yuehhua for creating [GeometricFlux.jl](https://github.com/FluxML/GeometricFlux.jl) of which GraphNeuralNetworks.jl is a radical redesign.

docs/src/gnngraph.md

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,6 @@ g′ = remove_self_loops(g)
138138
g′ = add_edges(g, [1, 2], [2, 3]) # add edges 1->2 and 2->3
139139
```
140140

141-
## JuliaGraphs ecosystem integration
142-
143-
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
144-
145-
```julia
146-
@assert Graphs.isdirected(g)
147-
```
148-
149141
## GPU movement
150142

151143
Move a `GNNGraph` to a CUDA device using `Flux.gpu` method.
@@ -155,3 +147,34 @@ using Flux: gpu
155147

156148
g_gpu = g |> gpu
157149
```
150+
151+
## JuliaGraphs/Graphs.jl integration
152+
153+
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
154+
Moreover, `GNNGraph`s can be constructed from `Graphs.Graph` and `Graphs.DiGraph`.
155+
156+
```julia
157+
julia> import Graphs
158+
159+
julia> using GraphNeuralNetworks
160+
161+
# A Graphs.jl undirected graph
162+
julia> gu = Graphs.erdos_renyi(10, 20)
163+
{10, 20} undirected simple Int64 graph
164+
165+
# Since GNNGraphs are undirected, the edges are doubled when converting
166+
# to GNNGraph
167+
julia> GNNGraph(gu) # Since GNNGraphs are
168+
GNNGraph:
169+
num_nodes = 10
170+
num_edges = 40
171+
172+
# A Graphs.jl directed graph
173+
julia> gd = Graphs.erdos_renyi(10, 20, is_directed=true)
174+
{10, 20} directed simple Int64 graph
175+
176+
julia> GNNGraph(gd)
177+
GNNGraph:
178+
num_nodes = 10
179+
num_edges = 20
180+
```

docs/src/index.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# GraphNeuralNetworks
22

3-
This is the documentation page for the [GraphNeuralNetworks.jl](https://github.com/CarloLucibello/GraphNeuralNetworks.jl) library.
4-
5-
A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). GNN.jl is largely inspired by python's libraries [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) and [Deep Graph Library](https://docs.dgl.ai/),
6-
and by julia's [GeometricFlux](https://fluxml.ai/GeometricFlux.jl/stable/).
3+
This is the documentation page for [GraphNeuralNetworks.jl](https://github.com/CarloLucibello/GraphNeuralNetworks.jl), a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
4+
GNN.jl is largely inspired by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/),[Deep Graph Library](https://docs.dgl.ai/),
5+
and [GeometricFlux.jl](https://fluxml.ai/GeometricFlux.jl/stable/).
76

87
Among its features:
98

10-
* Integratation with the JuliaGraphs ecosystem.
11-
* Implementation of common graph convolutional layers.
12-
* Fast operations on batched graphs.
9+
* Implements common graph convolutional layers.
10+
* Supports computations on batched graphs.
1311
* Easy to define custom layers.
12+
* Integration with the JuliaGraphs ecosystem.
1413
* CUDA support.
14+
* Integration with [Graph.jl](https://github.com/JuliaGraphs/Graphs.jl).
15+
* [Examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks.
1516

1617

1718
## Package overview
@@ -45,7 +46,6 @@ GNNGraph:
4546
num_graphs = 1000
4647
ndata:
4748
x => (16, 10000)
48-
edata:
4949
gdata:
5050
y => (1000,)
5151
```

examples/link_prediction_pubmed.jl

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# An example of link prediction using negative and positive samples.
22
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
3+
# See the comparison paper https://arxiv.org/pdf/2102.12557.pdf for more details
34

45
using Flux
5-
# Link prediction task
6-
# https://arxiv.org/pdf/2102.12557.pdf
7-
86
using Flux: onecold, onehotbatch
97
using Flux.Losses: logitbinarycrossentropy
108
using GraphNeuralNetworks
@@ -24,19 +22,19 @@ Base.@kwdef mutable struct Args
2422
infotime = 10 # report every `infotime` epochs
2523
end
2624

25+
# We define our own edge prediction layer but could also
26+
# use GraphNeuralNetworks.DotDecoder instead.
2727
struct DotPredictor end
2828

2929
function (::DotPredictor)(g, x)
3030
z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
31+
# z = apply_edges(xi_dot_xj, g, xi=x, xj=x) # Same with buit-in methods
3132
return vec(z)
3233
end
3334

34-
using ChainRulesCore
35-
3635
function train(; kws...)
37-
# args = Args(; kws...)
38-
args = Args()
39-
36+
args = Args(; kws...)
37+
4038
args.seed > 0 && Random.seed!(args.seed)
4139

4240
if args.usecuda && CUDA.functional()
@@ -50,10 +48,11 @@ function train(; kws...)
5048

5149
### LOAD DATA
5250
data = Cora.dataset()
51+
# data = PubMed.dataset()
5352
g = GNNGraph(data.adjacency_list) |> device
53+
@show is_bidirected(g)
5454
X = data.node_features |> device
5555

56-
5756
#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
5857
s, t = edge_index(g)
5958
eids = randperm(g.num_edges)
@@ -67,9 +66,11 @@ function train(; kws...)
6766

6867
test_neg_g = negative_sample(g, num_neg_edges=test_size)
6968

69+
7070
### DEFINE MODEL #########
7171
nin, nhidden = size(X,1), args.nhidden
7272

73+
# We embed the graph with positive training edges in the model
7374
model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
7475
GCNConv(nhidden => nhidden)),
7576
train_pos_g) |> device
@@ -84,7 +85,7 @@ function train(; kws...)
8485
function loss(pos_g, neg_g = nothing)
8586
h = model(X)
8687
if neg_g === nothing
87-
# we sample a negative graph at each training step
88+
# We sample a negative graph at each training step
8889
neg_g = negative_sample(pos_g)
8990
end
9091
pos_score = pred(pos_g, h)
@@ -93,15 +94,6 @@ function train(; kws...)
9394
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
9495
return logitbinarycrossentropy(scores, labels)
9596
end
96-
97-
# function accuracy(pos_g, neg_g)
98-
# h = model(train_pos_g, X)
99-
# pos_score = pred(pos_g, h)
100-
# neg_score = pred(neg_g, h)
101-
# scores = [pos_score; neg_score]
102-
# labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
103-
# return logitbinarycrossentropy(scores, labels)
104-
# end
10597

10698
### LOGGING FUNCTION
10799
function report(epoch)
File renamed without changes.

src/GNNGraphs/GNNGraphs.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,43 @@ using ChainRulesCore
1616
using LinearAlgebra, Random
1717

1818
include("gnngraph.jl")
19-
export GNNGraph, node_features, edge_features, graph_features
19+
export GNNGraph,
20+
node_features,
21+
edge_features,
22+
graph_features
2023

2124
include("query.jl")
22-
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
23-
graph_indicator
25+
export edge_index,
26+
adjacency_list,
27+
normalized_laplacian,
28+
scaled_laplacian,
29+
graph_indicator,
30+
is_bidirected,
31+
# from Graphs
32+
adjacency_matrix,
33+
degree,
34+
outneighbors,
35+
inneighbors
2436

2537
include("transform.jl")
26-
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph,
27-
negative_sample
38+
export add_nodes,
39+
add_edges,
40+
add_self_loops,
41+
remove_self_loops,
42+
remove_multi_edges,
43+
getgraph,
44+
negative_sample,
45+
# from Flux
46+
batch,
47+
unbatch,
48+
# from SparseArrays
49+
blockdiag
2850

2951
include("generate.jl")
3052
export rand_graph
3153

32-
3354
include("convert.jl")
3455
include("utils.jl")
3556

36-
export
37-
# from Graphs
38-
adjacency_matrix, degree, outneighbors, inneighbors,
39-
# from SparseArrays
40-
sprand, sparse, blockdiag,
41-
# from Flux
42-
batch, unbatch
43-
57+
4458
end #module

src/GNNGraphs/generate.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@ julia> g = rand_graph(5, 4, bidirected=false)
1919
GNNGraph:
2020
num_nodes = 5
2121
num_edges = 4
22-
num_graphs = 1
23-
ndata:
24-
edata:
25-
gdata:
26-
2722
2823
julia> edge_index(g)
2924
([1, 3, 3, 4], [5, 4, 5, 2])
@@ -33,11 +28,8 @@ julia> g = rand_graph(5, 4, edata=rand(16, 2))
3328
GNNGraph:
3429
num_nodes = 5
3530
num_edges = 4
36-
num_graphs = 1
37-
ndata:
3831
edata:
3932
e => (16, 4)
40-
gdata:
4133
4234
# Each edge has a reverse
4335
julia> edge_index(g)

src/GNNGraphs/gnngraph.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,19 +195,25 @@ end
195195
function Base.show(io::IO, g::GNNGraph)
196196
println(io, "GNNGraph:
197197
num_nodes = $(g.num_nodes)
198-
num_edges = $(g.num_edges)
199-
num_graphs = $(g.num_graphs)")
200-
println(io, " ndata:")
201-
for k in keys(g.ndata)
202-
println(io, " $k => $(size(g.ndata[k]))")
198+
num_edges = $(g.num_edges)")
199+
g.num_graphs > 1 && println("num_graphs = $(g.num_graphs)")
200+
if !isempty(g.ndata)
201+
println(io, " ndata:")
202+
for k in keys(g.ndata)
203+
println(io, " $k => $(size(g.ndata[k]))")
204+
end
203205
end
204-
println(io, " edata:")
205-
for k in keys(g.edata)
206-
println(io, " $k => $(size(g.edata[k]))")
206+
if !isempty(g.edata)
207+
println(io, " edata:")
208+
for k in keys(g.edata)
209+
println(io, " $k => $(size(g.edata[k]))")
210+
end
207211
end
208-
println(io, " gdata:")
209-
for k in keys(g.gdata)
210-
println(io, " $k => $(size(g.gdata[k]))")
212+
if !isempty(g.gdata)
213+
println(io, " gdata:")
214+
for k in keys(g.gdata)
215+
println(io, " $k => $(size(g.gdata[k]))")
216+
end
211217
end
212218
end
213219

src/GNNGraphs/query.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,20 @@ function graph_features(g::GNNGraph)
229229
end
230230
end
231231

232+
"""
233+
is_bidirected(g::GNNGraph)
234+
235+
Check if the directed graph `g` essentially corresponds
236+
to an undirected graph, i.e. if for each edge it also contains the
237+
reverse edge.
238+
"""
239+
function is_bidirected(g::GNNGraph)
240+
s, t = edge_index(g)
241+
s1, t1 = sort_edge_index(s, t)
242+
s2, t2 = sort_edge_index(t, s)
243+
all((s1 .== s2) .& (t1 .== t2))
244+
end
245+
232246
@non_differentiable normalized_laplacian(x...)
233247
@non_differentiable normalized_adjacency(x...)
234248
@non_differentiable scaled_laplacian(x...)

0 commit comments

Comments
 (0)