Commit 48dcb51

tests, docs, deprecations
1 parent 4587f7f commit 48dcb51

16 files changed, +240 -293 lines

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ version = "0.1.2"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

docs/src/api/messagepassing.md

Lines changed: 1 addition & 3 deletions
@@ -14,8 +14,6 @@ Pages = ["messagepassing.md"]
 ## Docs
 
 ```@docs
-compute_message
-update_node
-update_edge
+apply_edges
 propagate
 ```
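For orientation, here is a minimal sketch of how the two remaining entry points are used after this commit, based on the `propagate(f, g, aggr; [xi, xj, e])` signature stated in src/deprecations.jl below (the graph construction and feature sizes are illustrative):

```julia
using GraphNeuralNetworks, LightGraphs

g = GNNGraph(erdos_renyi(10, 30))
x = rand(Float32, 4, g.num_nodes)  # one feature column per node

# The message function receives target features `xi`, source features `xj`,
# and edge features `e`; here every edge simply forwards the source features.
msg(xi, xj, e) = xj

x̄ = propagate(msg, g, +, xj = x)  # sum incoming messages; size (4, g.num_nodes)
```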

docs/src/gnngraph.md

Lines changed: 51 additions & 13 deletions
@@ -1,6 +1,6 @@
 # Graphs
 
-The fundamental graph type in GraphNeuralNetworks.jl is the [`GNNGraph`](@ref),
+The fundamental graph type in GraphNeuralNetworks.jl is the [`GNNGraph`](@ref).
 A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`.
 The underlying implementation allows for efficient application of graph neural network
 operators, gpu movement, and storage of node/edge/graph related feature arrays.
@@ -32,25 +32,50 @@ g = GNNGraph(source, target)
 
 See also the related methods [`adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
 
+## Basic Queries
+
+```julia
+source = [1,1,2,2,3,3,3,4]
+target = [2,3,1,3,1,2,4,3]
+g = GNNGraph(source, target)
+
+@assert g.num_nodes == 4   # number of nodes
+@assert g.num_edges == 8   # number of edges
+@assert g.num_graphs == 1  # number of subgraphs (a GNNGraph can batch many graphs together)
+is_directed(g)             # a GNNGraph is always directed
+```
 
 ## Data Features
 
+One or more arrays can be associated to nodes, edges, and (sub)graphs of a `GNNGraph`.
+They will be stored in the fields `g.ndata`, `g.edata`, and `g.gdata` respectively.
+The data fields are `NamedTuple`s. The arrays they contain must have last dimension
+equal to `num_nodes` (in `ndata`), `num_edges` (in `edata`), or `num_graphs` (in `gdata`).
+
 ```julia
 # Create a graph with a single feature array `x` associated to nodes
 g = GNNGraph(erdos_renyi(10, 30), ndata = (; x = rand(Float32, 32, 10)))
-# Equivalent definition
+
+g.ndata.x  # access the features
+
+# Equivalent definition passing directly the array
 g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))
 
+g.ndata.x  # `:x` is the default name for node features
+
 # You can have multiple feature arrays
 g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
 
+g.ndata.y, g.ndata.x
 
 # Attach an array with edge features.
 # Since `GNNGraph`s are directed, the number of edges
 # will be double that of the original LightGraphs' undirected graph.
 g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 60))
 @assert g.num_edges == 60
 
+g.edata.e
+
 # If we pass only half of the edge features, they will be copied
 # on the reversed edges.
 g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
@@ -59,28 +84,30 @@ g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
 # Create a new graph from previous one, inheriting edge data
 # but replacing node data
 g′ = GNNGraph(g, ndata =(; z = ones(Float32, 16, 10)))
-```
-
-
-## Graph Manipulation
 
-```julia
-g′ = add_self_loops(g)
-
-g′ = remove_self_loops(g)
+g′.ndata.z
+g′.edata.e
 ```
 
 ## Batches and Subgraphs
 
+Multiple `GNNGraph`s can be batched together into a single graph
+containing the total number of the original nodes
+and where the original graphs are disjoint subgraphs.
+
 ```julia
 using Flux
 
 gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])
 
-g23 = getgraph(gall, 2:3)
+@assert gall.num_graphs == 160
+@assert gall.num_nodes == 1600   # 10 nodes x 160 graphs
+@assert gall.num_edges == 9600   # 30 undirected edges x 2 directions x 160 graphs
+
+g23, _ = getgraph(gall, 2:3)
 @assert g23.num_graphs == 2
-@assert g23.num_nodes == 20
-@assert g23.num_edges == 120 # 30 undirected edges x 2 graphs
+@assert g23.num_nodes == 20    # 10 nodes x 2 graphs
+@assert g23.num_edges == 120   # 30 undirected edges x 2 directions x 2 graphs
 
 
 # DataLoader compatibility
@@ -92,6 +119,17 @@ for g in train_loader
     @assert size(g.ndata.x) == (3, 160)
     .....
 end
+
+# Access the nodes' graph memberships through
+gall.graph_indicator
+```
+
+## Graph Manipulation
+
+```julia
+g′ = add_self_loops(g)
+
+g′ = remove_self_loops(g)
 ```
 
 ## JuliaGraphs ecosystem integration
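As a companion to the query methods mentioned at the top of this page, a minimal sketch (reusing the `g` from the Basic Queries snippet; the exact return types are an assumption):

```julia
s, t = edge_index(g)      # source and target node vectors, one entry per edge
A = adjacency_matrix(g)   # adjacency matrix of the graph
al = adjacency_list(g)    # neighbor list for each node
@assert size(A) == (g.num_nodes, g.num_nodes)
```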

docs/src/index.md

Lines changed: 13 additions & 11 deletions
@@ -3,12 +3,14 @@
 This is the documentation page for the [GraphNeuralNetworks.jl](https://github.com/CarloLucibello/GraphNeuralNetworks.jl) library.
 
 A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
-Its most relevant features are:
-* Provides CUDA support.
-* It's integrated with the JuliaGraphs ecosystem.
-* Implements many common graph convolutional layers.
-* Performs fast operations on batched graphs.
-* Makes it easy to define custom graph convolutional layers.
+
+Among its features:
+
+* Integration with the JuliaGraphs ecosystem.
+* Many common graph convolutional layers.
+* Fast operations on batched graphs.
+* Easy definition of custom graph convolutional layers.
+* CUDA support.
 
 
 ## Package overview
@@ -50,17 +52,17 @@ GNNGraph:
 
 ### Model building
 
-We concisely define our model using as a [`GNNChain`](@ref) containing 2 graph convolutaional
+We concisely define our model as a [`GNNChain`](@ref) containing 2 graph convolutional
 layers. If CUDA is available, our model will live on the gpu.
 
 ```julia
 julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;
 
 julia> model = GNNChain(GCNConv(16 => 64),
-                        BatchNorm(64),
-                        x -> relu.(x),
+                        BatchNorm(64),     # apply batch normalization on node features (node dimension is batch dimension)
+                        x -> relu.(x),
                         GCNConv(64 => 64, relu),
-                        GlobalPool(mean),
+                        GlobalPool(mean),  # aggregate node-wise features into graph-wise features
                         Dense(64, 1)) |> device;
 
 julia> ps = Flux.params(model);
@@ -86,7 +88,7 @@ loss(loader) = mean(loss(g |> device) for g in loader)
 
 for epoch in 1:100
     for g in train_loader
-        g = g |> gpu
+        g = g |> device
         grad = gradient(() -> loss(g), ps)
         Flux.Optimise.update!(opt, ps, grad)
     end
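The change from `gpu` to `device` in the loop above matters because `device` is resolved at runtime; a minimal sketch of the pattern:

```julia
using Flux, CUDA

device = CUDA.functional() ? Flux.gpu : Flux.cpu
x = rand(Float32, 3, 5) |> device  # moves to gpu only when CUDA is actually usable
```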

docs/src/messagepassing.md

Lines changed: 6 additions & 12 deletions
@@ -1,22 +1,19 @@
 # Message Passing
 
-The message passing is initiated by [`propagate`](@ref)
-and can be customized for a specific layer by overloading the methods
-[`compute_message`](@ref), [`update_node`](@ref), and [`update_edge`](@ref).
-
-The message passing corresponds to the following operations
+Message passing is initiated by the [`propagate`](@ref) function
+and generally takes the form
 
 ```math
 \begin{aligned}
 \mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
 \bar{\mathbf{m}}_{i} &= \square_{j\in N(i)} \mathbf{m}_{j\to i} \\
-
 \mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \bar{\mathbf{m}}_{i})\\
 \mathbf{e}_{j\to i}^\prime &= \gamma_e(\mathbf{e}_{j \to i},\mathbf{m}_{j \to i})
 \end{aligned}
 ```
-where ``\phi`` is expressed by the [`compute_message`](@ref) function,
-``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
+
+where we refer to ``\phi`` as the message function,
+and to ``\gamma_x`` and ``\gamma_e`` as the node update and edge update functions
 respectively. The generic aggregation ``\square`` usually is given by a summation
 ``\sum``, a max or a mean operation.
 
@@ -35,7 +32,6 @@ where ``c_{ij} = \sqrt{|N(i)||N(j)|}``. We will also add a bias and an activation
 
 ```julia
 using Flux, LightGraphs, GraphNeuralNetworks
-import GraphNeuralNetworks: compute_message, update_node, propagate
 
 struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
     weight::A
@@ -52,12 +48,10 @@ function GCN(ch::Pair{Int,Int}, σ=identity)
     GCN(W, b, σ)
 end
 
-compute_message(l::GCN, xi, xj, eij) = l.weight * xj
-
 function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    x = propagate(l, g, +, x)
+    x = propagate((xi, xj, e) -> l.weight * xj, g, +, xj=x)
     x = x .* c'
     return l.σ.(x .+ l.bias)
 end
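A quick smoke test of the `GCN` layer defined above (a sketch continuing the same snippet, so the `using` statements and layer definition are already in scope; the graph and sizes are illustrative):

```julia
g = GNNGraph(random_regular_graph(10, 4))
l = GCN(3 => 5, relu)
x = rand(Float32, 3, g.num_nodes)
y = l(g, x)
@assert size(y) == (5, g.num_nodes)  # one output column per node
```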

docs/src/models.md

Lines changed: 7 additions & 8 deletions
@@ -15,13 +15,12 @@ In the explicit modeling style, the model is created according to the following steps:
 1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
 2. Apply `Flux.@functor` to the new type to make it compatible with Flux (parameter collection, gpu movement, etc...)
 3. Optionally define a convenience constructor for your model.
-4. Define the forward pass by implementing the function call method for your type
+4. Define the forward pass by implementing the call method for your type.
 5. Instantiate the model.
 
 Here is an example of this construction:
 ```julia
 using Flux, LightGraphs, GraphNeuralNetworks
-using Flux: @functor
 
 struct GNN # step 1
     conv1
@@ -31,7 +30,7 @@ struct GNN # step 1
     dense
 end
 
-@functor GNN # step 2
+Flux.@functor GNN # step 2
 
 function GNN(din::Int, d::Int, dout::Int) # step 3
     GNN(GCNConv(din => d),
@@ -51,10 +50,12 @@ function (model::GNN)(g::GNNGraph, x) # step 4
 end
 
 din, d, dout = 3, 4, 2
-g = GNNGraph(random_regular_graph(10, 4))
-X = randn(Float32, din, 10)
 model = GNN(din, d, dout) # step 5
-y = model(g, X)
+
+g = GNNGraph(random_regular_graph(10, 4))
+X = randn(Float32, din, 10)
+y = model(g, X) # output size: (dout, g.num_nodes)
+gs = gradient(() -> sum(model(g, X)), Flux.params(model))
 ```
 
 ## Implicit modeling with GNNChains
@@ -81,8 +82,6 @@ model = GNNChain(GCNConv(din => d),
                  GCNConv(d => d, relu),
                  Dropout(0.5),
                  Dense(d, dout))
-
-y = model(g, X) # output size: (dout, g.num_nodes)
 ```
 
 The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when also edge features are updated, have to be handled using the explicit definition of the forward pass.
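Using the implicit model looks the same as using the explicit one (a sketch reusing `g` and `X` from the explicit example above):

```julia
y = model(g, X)  # output size: (dout, g.num_nodes)
gs = gradient(() -> sum(model(g, X)), Flux.params(model))
```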

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ export
   sprand, sparse,
 
   # msgpass
-  update_node, update_edge, compute_message, propagate,
+  apply_edges, propagate,
 
   # layers/basic
   GNNLayer,

src/deprecations.jl

Lines changed: 25 additions & 4 deletions
@@ -2,9 +2,30 @@
 
 @deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
 
-# TO Deprecate
-# x, _ = propagate(l, g, l.aggr, x, e)
 
-# # TODO deprecate
-# propagate(l, g::GNNGraph, aggr, x, e=nothing) = propagate(l, g, aggr; x, e)
+# Deprecated in v0.2
+# TODO check if argument order is exact
+function compute_message end
+function update_node end
+function update_edge end
 
+compute_message(l, xi, xj, e) = compute_message(l, xi, xj)
+update_node(l, x, m̄) = m̄
+update_edge(l, e, m) = e
+
+function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
+    @warn """
+        Passing a GNNLayer to propagate is deprecated,
+        you should pass directly the message function.
+        The new signature is `propagate(f, g, aggr; [xi, xj, e])`.
+
+        Also the functions `compute_message`, `update_node`,
+        and `update_edge` have been deprecated. Please
+        refer to the documentation.
+        """
+    m = apply_edges((a...) -> compute_message(l, a...), g, x, x, e)
+    m̄ = aggregate_neighbors(g, aggr, m)
+    x = update_node(l, x, m̄)
+    e = update_edge(l, e, m)
+    return x, e
+end
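To make the migration concrete, here is a sketch of the call shapes this shim translates between (the old form is reconstructed from the deprecated method above; `l`, `g`, and `x` are illustrative):

```julia
# Old, deprecated style: the layer carries the message function
# through `compute_message`, and both node and edge features are returned.
#   x, e = propagate(l, g, +, x, e)

# New style: pass the message function directly, inputs as keywords.
x̄ = propagate((xi, xj, e) -> xj, g, +, xj = x)
```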

src/gnngraph.jl

Lines changed: 3 additions & 3 deletions
@@ -50,10 +50,10 @@ from the LightGraphs' graph library can be used on it.
     - `:sparse`. A sparse adjacency matrix representation.
     - `:dense`. A dense adjacency matrix representation.
     Default `:coo`.
-- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
+- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
     Possible values are `:out` and `:in`. Default `:out`.
-- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
-- `graph_indicator`. For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
+- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
+- `graph_indicator`: For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
 - `ndata`: Node features. A named tuple of arrays whose last dimension has size num_nodes.
 - `edata`: Edge features. A named tuple of arrays whose last dimension has size num_edges.
 - `gdata`: Global features. A named tuple of arrays whose last dimension has size num_graphs.
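A small sketch exercising the keyword arguments documented above (the value combination is illustrative):

```julia
source, target = [1, 2], [2, 1]
g = GNNGraph(source, target;
             num_nodes = 3,                        # one extra isolated node
             ndata = (; x = rand(Float32, 8, 3)))  # last dimension == num_nodes
@assert g.num_nodes == 3
```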
