
Commit adb3cdb

docs
1 parent b4494e1 commit adb3cdb

3 files changed: +75 −37 lines changed

docs/src/messagepassing.md

Lines changed: 40 additions & 15 deletions
````diff
@@ -1,7 +1,6 @@
 # Message Passing
 
-The message passing is initiated by the [`propagate`](@ref) function
-and generally takes the form
+A generic message passing on a graph takes the form
 
 ```math
 \begin{aligned}
@@ -13,22 +12,41 @@ and generally takes the form
 ```
 
 where we refer to ``\phi`` as the message function,
-and to ``\gamma_x`` and ``\gamma_e`` as the node update and edge update function
-respectively. The generic aggregation ``\square`` usually is given by a summation
-``\sum``, a max or a mean operation.
+and to ``\gamma_x`` and ``\gamma_e`` as the node update and edge update functions
+respectively. The aggregation ``\square`` is over the neighborhood ``N(i)`` of node ``i``,
+and is usually set to a summation ``\sum``, a max, or a mean operation.
 
````
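The equation block itself sits between the two hunks and is not shown in the diff. From the definitions of ``\phi``, ``\gamma_x``, ``\gamma_e``, and ``\square`` in the surrounding text, it presumably takes the standard form (a reconstruction for the reader's convenience, not verbatim file content):

```math
\begin{aligned}
\mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
\bar{\mathbf{m}}_{i} &= \square_{j \in N(i)} \mathbf{m}_{j\to i} \\
\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_i, \bar{\mathbf{m}}_i) \\
\mathbf{e}_{j\to i}' &= \gamma_e(\mathbf{e}_{j\to i}, \mathbf{m}_{j\to i})
\end{aligned}
```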
````diff
-The message propagation mechanism internally relies on the [`NNlib.gather`](@ref)
+In GNN.jl, the function [`propagate`](@ref) takes care of materializing the
+node features on each edge, applying the message function, performing the
+aggregation, and returning ``\bar{\mathbf{m}}``.
+It is then left to the user to perform further node and edge updates,
+manipulating arrays of size ``D_{node} \times num\_nodes`` and
+``D_{edge} \times num\_edges``.
+
+As part of the [`propagate`](@ref) pipeline, we have the function
+[`apply_edges`](@ref). It can be used independently to materialize
+node features on edges and perform edge-related computations without
+the subsequent neighborhood aggregation found in `propagate`.
+
+The whole propagation mechanism internally relies on the [`NNlib.gather`](@ref)
 and [`NNlib.scatter`](@ref) methods.
 
-## An example: implementing the GCNConv
 
-Let's (re-)implement the [`GCNConv`](@ref) layer using the message passing framework.
+## Examples
+
+### Basic use of propagate and apply_edges
+
````
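This new subsection is left empty in the commit. A minimal sketch of what basic usage could look like, based on the `propagate` and `apply_edges` signatures documented in `src/msgpass.jl` below (the small cycle graph and the feature sizes are made up for illustration):

```julia
using GraphNeuralNetworks

# A directed cycle over 3 nodes, with 4-dimensional node features.
g = GNNGraph([1, 2, 3], [2, 3, 1])
x = rand(Float32, 4, g.num_nodes)

# apply_edges materializes node features on each edge and applies the
# message function, with no aggregation. Result size: (4, g.num_edges).
m = apply_edges((xi, xj, e) -> xi .+ xj, g, xi=x, xj=x)

# propagate additionally aggregates the messages over each node's
# neighborhood. Result size: (4, g.num_nodes).
m̄ = propagate((xi, xj, e) -> xi .+ xj, g, +, xi=x, xj=x)
```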
````diff
+### Implementing a custom Graph Convolutional Layer
+
+Let's implement a simple graph convolutional layer using the message passing framework.
 The convolution reads
 
 ```math
-\mathbf{x}'_i = \sum_{j \in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
+\mathbf{x}'_i = W \cdot \sum_{j \in N(i)} \mathbf{x}_j
 ```
-where ``c_{ij} = \sqrt{|N(i)||N(j)|}``. We will also add a bias and an activation function.
+We will also add a bias and an activation function.
 
 ```julia
 using Flux, LightGraphs, GraphNeuralNetworks
@@ -49,11 +67,18 @@ function GCN(ch::Pair{Int,Int}, σ=identity)
 end
 
 function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
-    c = 1 ./ sqrt.(degree(g, T, dir=:in))
-    x = x .* c'
-    x = propagate((xi, xj, e) -> l.weight * xj, g, +, xj=x)
-    x = x .* c'
-    return l.σ.(x .+ l.bias)
+    @assert size(x, 2) == g.num_nodes
+
+    # Compute messages from source/neighbour nodes (j) to target/root nodes (i).
+    # The message function has to handle matrices of size (*, num_edges);
+    # in this simple case we just let the neighbor features pass through.
+    message(xi, xj, e) = xj
+
+    # The + operator gives the sum aggregation.
+    # `mean`, `max`, `min`, and `*` are other possibilities.
+    x = propagate(message, g, +, xj=x)
+
+    return l.σ.(l.weight * x .+ l.bias)
 end
 ```
````
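A quick smoke test of the new layer might look like this (a sketch; the graph and the feature dimensions are illustrative, reusing `random_regular_graph` from the `models.md` snippet below):

```julia
using Flux, LightGraphs, GraphNeuralNetworks

g = GNNGraph(random_regular_graph(10, 4))
x = randn(Float32, 3, g.num_nodes)

l = GCN(3 => 5, relu)
y = l(g, x)    # size: (5, g.num_nodes)
```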

docs/src/models.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -54,6 +54,7 @@ model = GNN(din, d, dout) # step 5
 
 g = GNNGraph(random_regular_graph(10, 4))
 X = randn(Float32, din, 10)
+
 y = model(g, X) # output size: (dout, g.num_nodes)
 gs = gradient(() -> sum(model(g, X)), Flux.params(model))
 ```
````

src/msgpass.jl

Lines changed: 34 additions & 22 deletions
````diff
@@ -1,26 +1,39 @@
 """
     propagate(f, g, aggr; xi, xj, e) -> m̄
 
-Performs the message passing scheme on graph `g`.
-Returns the aggregated node features `m̄` computed
+Performs message passing on graph `g`.
 
-The computational steps are the following:
+Takes care of materializing the node features on each edge,
+applying the message function, and returning an aggregated message ``\bar{\mathbf{m}}``
+(depending on the return value of `f`, an array or a named tuple of
+arrays with last dimension's size `g.num_nodes`).
+
+It can be decomposed into two steps:
 
 ```julia
 m = apply_edges(f, g, xi, xj, e)
 m̄ = aggregate_neighbors(g, aggr, m)
 ```
 
-GNN layers typically call propagate in their forward pass.
+GNN layers typically call `propagate` in their forward pass,
+providing a closure as the input `f`.
 
 # Arguments
 
+- `g`: A `GNNGraph`.
+- `xi`: An array or a named tuple containing arrays whose last dimension's size
+  is `g.num_nodes`. It will be appropriately materialized on the
+  target node of each edge (see also [`edge_index`](@ref)).
+- `xj`: As `xi`, but to be materialized on edges' sources.
+- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`.
 - `f`: A generic function that will be passed over to [`apply_edges`](@ref).
-  Takes as inputs `xi`, `xj`, and `e`
-  (target nodes' features, source nodes' features, and edge features
-  respectively) and returns new edge features `m`.
+  Has to take as inputs the edge-materialized `xi`, `xj`, and `e`
+  (arrays or named tuples of arrays whose last dimension's size is the size of
+  a batch of edges). Its output has to be an array or a named tuple of arrays
+  with the same batch size.
+- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`.
 
-# Usage example
+# Usage Examples
 
 ```julia
 using GraphNeuralNetworks, Flux
@@ -68,29 +81,28 @@ end
 """
     apply_edges(f, xi, xj, e)
 
-Message function for the message-passing scheme
-started by [`propagate`](@ref).
 Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
 in order to update the features of node `i`.
 
 The function operates on batches of edges, therefore
 `xi`, `xj`, and `e` are tensors whose last dimension
-is the batch size, or can be tuples/named tuples of
-such tensors, according to the input to propagate.
-
-By default, the function returns `xj`.
-Custom layers should specialize this method with the desired behavior.
-
+is the batch size, or can be named tuples of
+such tensors.
+
 # Arguments
 
-- `f`: A function that takes as inputs `xi`, `xj`, and `e`
-  (target nodes' features, source nodes' features, and edge features
-  respectively) and returns new edge features `m`.
-- `xi`: Features of the central node `i`.
-- `xj`: Features of the neighbor `j` of node `i`.
-- `eij`: Features of edge `(i,j)`.
+- `g`: A `GNNGraph`.
+- `xi`: An array or a named tuple containing arrays whose last dimension's size
+  is `g.num_nodes`. It will be appropriately materialized on the
+  target node of each edge (see also [`edge_index`](@ref)).
+- `xj`: As `xi`, but to be materialized on edges' sources.
+- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`.
+- `f`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`.
+  These are arrays (or named tuples of arrays) whose last dimension's size is the size of
+  a batch of edges. The output of `f` has to be an array (or a named tuple of arrays)
+  with the same batch size.
 
 See also [`propagate`](@ref).
 """
````
