Commit b2f6ebf

Merge pull request #47 from CarloLucibello/cl/redesign
redesign message passing mechanism
2 parents 57b0c56 + d4ddd95 commit b2f6ebf

25 files changed: +710 −581 lines

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "0.1.2"
+version = "0.2.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -25,7 +25,7 @@ Adapt = "3"
 CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"
-Flux = "0.12"
+Flux = "0.12.7"
 KrylovKit = "0.5"
 LearnBase = "0.4, 0.5"
 LightGraphs = "1.3"

README.md

Lines changed: 8 additions & 8 deletions
@@ -5,13 +5,13 @@
 ![](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg)
 [![codecov](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl)

-A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
-Its most relevant features are:
-* Provides CUDA support.
-* It's integrated with the JuliaGraphs ecosystem.
-* Implements many common graph convolutional layers.
-* Performs fast operations on batched graphs.
-* Makes it easy to define custom graph convolutional layers.
+A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). Its features include:
+
+* Integration with the JuliaGraphs ecosystem.
+* Implementation of common graph convolutional layers.
+* Fast operations on batched graphs.
+* Easy definition of custom layers.
+* CUDA support.

 ## Installation

@@ -28,4 +28,4 @@ Usage examples can be found in the [examples](https://github.com/CarloLucibello/

 ## Acknowledgements

-A big thank you goes to @yuehhua for creating [GeometricFlux.jl](https://github.com/FluxML/GeometricFlux.jl) of which GraphNeuralNetworks.jl is a radical redesign.
+A big thanks goes to @yuehhua for creating [GeometricFlux.jl](https://github.com/FluxML/GeometricFlux.jl), of which GraphNeuralNetworks.jl is a radical redesign.

docs/src/api/messagepassing.md

Lines changed: 8 additions & 4 deletions
@@ -11,11 +11,15 @@ Order = [:type, :function]
 Pages = ["messagepassing.md"]
 ```

-## Docs
+## Interface

 ```@docs
-compute_message
-update_node
-update_edge
+apply_edges
 propagate
 ```
+
+## Built-in message functions
+
+```@docs
+copyxj
+```
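For orientation, here is a minimal sketch of how the two documented entry points fit together, assuming the keyword-based signature `propagate(f, g, aggr; xi, xj, e)` introduced by this redesign:

```julia
using GraphNeuralNetworks, LightGraphs

g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 8, 10))
x = g.ndata.x

# copyxj is the built-in message (xi, xj, e) -> xj: each edge carries its
# source node's features, which `+` then sums over every node's neighborhood.
m̄ = propagate(copyxj, g, +, xj = x)   # size (8, g.num_nodes)

# apply_edges returns the per-edge messages without aggregating them.
m = apply_edges(copyxj, g, xj = x)    # size (8, g.num_edges)
```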

docs/src/gnngraph.md

Lines changed: 51 additions & 13 deletions
@@ -1,6 +1,6 @@
 # Graphs

-The fundamental graph type in GraphNeuralNetworks.jl is the [`GNNGraph`](@ref),
+The fundamental graph type in GraphNeuralNetworks.jl is the [`GNNGraph`](@ref).
 A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`.
 The underlying implementation allows for efficient application of graph neural network
 operators, gpu movement, and storage of node/edge/graph related feature arrays.
@@ -32,25 +32,50 @@ g = GNNGraph(source, target)

 See also the related methods [`adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).

+## Basic Queries
+
+```julia
+source = [1,1,2,2,3,3,3,4]
+target = [2,3,1,3,1,2,4,3]
+g = GNNGraph(source, target)
+
+@assert g.num_nodes == 4   # number of nodes
+@assert g.num_edges == 8   # number of edges
+@assert g.num_graphs == 1  # number of subgraphs (a GNNGraph can batch many graphs together)
+is_directed(g)             # a GNNGraph is always directed
+```

 ## Data Features

+One or more arrays can be associated to nodes, edges, and (sub)graphs of a `GNNGraph`.
+They will be stored in the fields `g.ndata`, `g.edata`, and `g.gdata` respectively.
+The data fields are `NamedTuple`s. The arrays they contain must have their last dimension
+equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`).
+
 ```julia
 # Create a graph with a single feature array `x` associated to nodes
 g = GNNGraph(erdos_renyi(10, 30), ndata = (; x = rand(Float32, 32, 10)))
-# Equivalent definition
+
+g.ndata.x  # access the features
+
+# Equivalent definition, passing the array directly
 g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))

+g.ndata.x  # `:x` is the default name for node features
+
 # You can have multiple feature arrays
 g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))

+g.ndata.y, g.ndata.x

 # Attach an array with edge features.
 # Since `GNNGraph`s are directed, the number of edges
 # will be double that of the original LightGraphs' undirected graph.
 g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 60))
 @assert g.num_edges == 60

+g.edata.e
+
 # If we pass only half of the edge features, they will be copied
 # on the reversed edges.
 g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
@@ -59,28 +84,30 @@ g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
 # Create a new graph from previous one, inheriting edge data
 # but replacing node data
 g′ = GNNGraph(g, ndata = (; z = ones(Float32, 16, 10)))
-```
-
-
-## Graph Manipulation

-```julia
-g′ = add_self_loops(g)
-
-g′ = remove_self_loops(g)
+g′.ndata.z
+g′.edata.e
 ```

 ## Batches and Subgraphs

+Multiple `GNNGraph`s can be batched together into a single graph
+that contains all the nodes of the original graphs,
+which become its disjoint subgraphs.
+
 ```julia
 using Flux

 gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])

-g23 = getgraph(gall, 2:3)
+@assert gall.num_graphs == 160
+@assert gall.num_nodes == 1600   # 10 nodes x 160 graphs
+@assert gall.num_edges == 9600   # 30 undirected edges x 2 directions x 160 graphs
+
+g23, _ = getgraph(gall, 2:3)
 @assert g23.num_graphs == 2
-@assert g23.num_nodes == 20
-@assert g23.num_edges == 120 # 30 undirected edges x 2 graphs
+@assert g23.num_nodes == 20      # 10 nodes x 2 graphs
+@assert g23.num_edges == 120     # 30 undirected edges x 2 directions x 2 graphs


 # DataLoader compatibility
@@ -92,6 +119,17 @@ for g in train_loader
 @assert size(g.ndata.x) == (3, 160)
 .....
 end
+
+# Access the nodes' graph memberships through
+gall.graph_indicator
+```
+
+## Graph Manipulation
+
+```julia
+g′ = add_self_loops(g)
+
+g′ = remove_self_loops(g)
 ```

 ## JuliaGraphs ecosystem integration
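Putting the new sections together, a short sketch of how features and batching interact (the graph-level feature name `u` is illustrative, not mandated by the API):

```julia
using GraphNeuralNetworks, LightGraphs, Flux

# Feature arrays end with the node/edge/graph dimension.
g = GNNGraph(erdos_renyi(10, 30),
             ndata = (; x = rand(Float32, 32, 10)),  # 10 nodes
             edata = rand(Float32, 60),              # 60 directed edges
             gdata = (; u = rand(Float32, 5, 1)))    # 1 graph

@assert size(g.ndata.x, 2) == g.num_nodes
@assert length(g.edata.e) == g.num_edges

# Batching yields one graph whose subgraphs are disjoint,
# and whose nodes remember their graph of origin.
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 3, 10)) for _ in 1:5])
@assert gall.num_graphs == 5
@assert length(gall.graph_indicator) == gall.num_nodes
```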

docs/src/index.md

Lines changed: 20 additions & 17 deletions
@@ -2,26 +2,29 @@

 This is the documentation page for the [GraphNeuralNetworks.jl](https://github.com/CarloLucibello/GraphNeuralNetworks.jl) library.

-A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
-Its most relevant features are:
-* Provides CUDA support.
-* It's integrated with the JuliaGraphs ecosystem.
-* Implements many common graph convolutional layers.
-* Performs fast operations on batched graphs.
-* Makes it easy to define custom graph convolutional layers.
+A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). GNN.jl is largely inspired by the Python libraries [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) and [Deep Graph Library](https://docs.dgl.ai/),
+and by Julia's [GeometricFlux](https://fluxml.ai/GeometricFlux.jl/stable/).
+
+Among its features:
+
+* Integration with the JuliaGraphs ecosystem.
+* Implementation of common graph convolutional layers.
+* Fast operations on batched graphs.
+* Easy definition of custom layers.
+* CUDA support.


 ## Package overview

-Let's give a brief overview of the package solving a
-graph regression problem on fake data.
+Let's give a brief overview of the package by solving a
+graph regression problem with synthetic data.

 Usage examples on real datasets can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) folder.

 ### Data preparation

 First, we create our dataset consisting of multiple random graphs and associated data features.
-that we batch together into a unique graph.
+Then we batch the graphs together into a unique graph.

 ```julia
 julia> using GraphNeuralNetworks, LightGraphs, Flux, CUDA, Statistics
@@ -50,17 +53,17 @@ GNNGraph:

 ### Model building

-We concisely define our model using as a [`GNNChain`](@ref) containing 2 graph convolutaional
+We concisely define our model as a [`GNNChain`](@ref) containing 2 graph convolutional
 layers. If CUDA is available, our model will live on the gpu.

 ```julia
 julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;

 julia> model = GNNChain(GCNConv(16 => 64),
-                        BatchNorm(64),
-                        x -> relu.(x),
+                        BatchNorm(64),     # apply batch normalization to node features (the node dimension is the batch dimension)
+                        x -> relu.(x),
                         GCNConv(64 => 64, relu),
-                        GlobalPool(mean),
+                        GlobalPool(mean),  # aggregate node-wise features into graph-wise features
                         Dense(64, 1)) |> device;

 julia> ps = Flux.params(model);
@@ -75,8 +78,8 @@ Flux's DataLoader iterates over mini-batches of graphs
 (batched together into a `GNNGraph` object).

 ```julia
-gtrain, _ = getgraph(gbatch, 1:800)
-gtest, _ = getgraph(gbatch, 801:gbatch.num_graphs)
+gtrain = getgraph(gbatch, 1:800)
+gtest = getgraph(gbatch, 801:gbatch.num_graphs)
 train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
 test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)

@@ -86,7 +89,7 @@ loss(loader) = mean(loss(g |> device) for g in loader)

 for epoch in 1:100
     for g in train_loader
-        g = g |> gpu
+        g = g |> device
         grad = gradient(() -> loss(g), ps)
         Flux.Optimise.update!(opt, ps, grad)
     end
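The hunks above do not show the dataset construction between the `using` statement and the `GNNGraph:` printout, nor the per-graph `loss` method; one plausible version of both, where the target name `y` and the use of `mse` are assumptions rather than part of the diff:

```julia
# Hypothetical dataset: 1000 random graphs, 16-dimensional node features,
# and one scalar regression target per graph.
all_graphs = [GNNGraph(erdos_renyi(10, 15),
                       ndata = (; x = randn(Float32, 16, 10)),
                       gdata = (; y = randn(Float32, 1)))   # assumed target name
              for _ in 1:1000]
gbatch = Flux.batch(all_graphs)

# A per-graph loss consistent with the loader-level definition in the diff:
loss(g::GNNGraph) = Flux.Losses.mse(vec(model(g, g.ndata.x)), g.gdata.y)
```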

docs/src/messagepassing.md

Lines changed: 50 additions & 23 deletions
@@ -1,38 +1,55 @@
 # Message Passing

-The message passing is initiated by [`propagate`](@ref)
-and can be customized for a specific layer by overloading the methods
-[`compute_message`](@ref), [`update_node`](@ref), and [`update_edge`](@ref).
-
-The message passing corresponds to the following operations
+A generic message passing scheme on a graph takes the form

 ```math
 \begin{aligned}
 \mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
-\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \square_{j\in N(i)} \mathbf{m}_{j\to i})\\
+\bar{\mathbf{m}}_{i} &= \square_{j\in N(i)} \mathbf{m}_{j\to i} \\
+\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \bar{\mathbf{m}}_{i})\\
 \mathbf{e}_{j\to i}^\prime &= \gamma_e(\mathbf{e}_{j \to i},\mathbf{m}_{j \to i})
 \end{aligned}
 ```
-where ``\phi`` is expressed by the [`compute_message`](@ref) function,
-``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
-respectively.

-The message propagation mechanism internally relies on the [`NNlib.gather`](@ref)
+where we refer to ``\phi`` as the message function,
+and to ``\gamma_x`` and ``\gamma_e`` as the node update and edge update functions
+respectively. The aggregation ``\square`` runs over the neighborhood ``N(i)`` of node ``i``
+and is usually a summation ``\sum``, a max, or a mean operation.
+
+In GNN.jl, the function [`propagate`](@ref) takes care of materializing the
+node features on each edge, applying the message function, performing the
+aggregation, and returning ``\bar{\mathbf{m}}``.
+It is then left to the user to perform further node and edge updates,
+manipulating arrays of size ``D_{node} \times \text{num\_nodes}`` and
+``D_{edge} \times \text{num\_edges}``.
+
+As part of the [`propagate`](@ref) pipeline, we have the function
+[`apply_edges`](@ref). It can be used independently to materialize
+node features on edges and perform edge-related computations, without
+the subsequent neighborhood aggregation performed by `propagate`.
+
+The whole propagation mechanism internally relies on the [`NNlib.gather`](@ref)
 and [`NNlib.scatter`](@ref) methods.

-## An example: implementing the GCNConv

-Let's (re-)implement the [`GCNConv`](@ref) layer use the message passing framework.
+## Examples
+
+### Basic use of propagate and apply_edges
+
+
+
+### Implementing a custom Graph Convolutional Layer
+
+Let's implement a simple graph convolutional layer using the message passing framework.
 The convolution reads

 ```math
-\mathbf{x}'_i = \sum_{j \in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
+\mathbf{x}'_i = W \cdot \sum_{j \in N(i)} \mathbf{x}_j
 ```
-where ``c_{ij} = \sqrt{|N(i)||N(j)|}``. We will also add a bias and an activation function.
+We will also add a bias and an activation function.

 ```julia
 using Flux, LightGraphs, GraphNeuralNetworks
-import GraphNeuralNetworks: compute_message, update_node, propagate

 struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
     weight::A
@@ -49,16 +66,26 @@ function GCN(ch::Pair{Int,Int}, σ=identity)
     GCN(W, b, σ)
 end

-compute_message(l::GCN, xi, xj, eij) = l.weight * xj
-update_node(l::GCN, m, x) = m
-
 function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
-    c = 1 ./ sqrt.(degree(g, T, dir=:in))
-    x = x .* c'
-    x, _ = propagate(l, g, +, x)
-    x = x .* c'
-    return l.σ.(x .+ l.bias)
+    @assert size(x, 2) == g.num_nodes
+
+    # Compute messages from source/neighbour nodes (j) to target/root nodes (i).
+    # The message function has to handle matrices of size (*, num_edges).
+    # In this simple case we just let the neighbor features pass through.
+    message(xi, xj, e) = xj
+
+    # The + operator gives sum aggregation.
+    # `mean`, `max`, `min`, and `*` are other possibilities.
+    x = propagate(message, g, +, xj=x)
+
+    return l.σ.(l.weight * x .+ l.bias)
 end
 ```

 See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example.
+
+
+## Built-in message functions
+
+In order to exploit optimized specializations of [`propagate`](@ref), it is recommended
+to use built-in message functions such as [`copyxj`](@ref) whenever possible.
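The "Basic use of propagate and apply_edges" section added above has no body yet in this commit; a hedged sketch of the intended usage, assuming the keyword signatures shown in the GCN example (the concatenation message is purely illustrative):

```julia
using GraphNeuralNetworks, LightGraphs, Flux

g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 16, 10))
x = g.ndata.x

# propagate: materialize messages on edges, then aggregate per node.
# Here phi(xi, xj, e) = xj and the aggregation is a sum.
m̄ = propagate((xi, xj, e) -> xj, g, +, xj = x)
@assert size(m̄) == (16, g.num_nodes)

# The optimized built-in equivalent of the message function above:
m̄ = propagate(copyxj, g, +, xj = x)

# apply_edges: edge-wise computation only, no aggregation.
# Concatenate the endpoint features of every edge (illustrative).
z = apply_edges((xi, xj, e) -> vcat(xi, xj), g, xi = x, xj = x)
@assert size(z) == (32, g.num_edges)

# The custom GCN layer from the example could then be used as:
l = GCN(16 => 8, relu)
y = l(g, x)   # size (8, g.num_nodes)
```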
