
Commit 05cfb80

Merge pull request #26 from CarloLucibello/cl/prop
Generalize message passing to handle Tuples and NamedTuples
2 parents 535d866 + f43086e commit 05cfb80

File tree

6 files changed: +109 −46 lines changed


docs/src/messagepassing.md

Lines changed: 18 additions & 1 deletion
@@ -1,4 +1,21 @@
 # Message Passing
 
-TODO
+The message passing is initiated by [`propagate`](@ref)
+and can be customized for a specific layer by overloading the methods
+[`compute_message`](@ref), [`update_node`](@ref), and [`update_edge`](@ref).
 
+
+The message passing corresponds to the following operations
+
+```math
+\begin{aligned}
+\mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
+\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \square_{j\in N(i)} \mathbf{m}_{j\to i})\\
+\mathbf{e}_{j\to i}^\prime &= \gamma_e(\mathbf{e}_{j \to i}, \mathbf{m}_{j \to i})
+\end{aligned}
+```
+where ``\phi`` is expressed by the [`compute_message`](@ref) function,
+``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref),
+respectively.
+
+See the [`GraphConv`](@ref) and [`GATConv`](@ref) implementations as usage examples.
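A minimal sketch of a custom layer built on these three hooks, assuming the API introduced in this commit; the layer name `MyConv`, its weight initialization, and the feature sizes are illustrative only, not part of the commit:

```julia
using Flux, GraphNeuralNetworks
import GraphNeuralNetworks: compute_message, update_node, propagate

struct MyConv <: GNNLayer   # hypothetical layer used only for illustration
    weight
end
MyConv(in::Int, out::Int) = MyConv(Flux.glorot_uniform(out, in))
Flux.@functor MyConv

# ϕ: message sent from node j to node i (edge features unused here)
compute_message(l::MyConv, x_i, x_j) = l.weight * x_j

# γ_x: combine the aggregated messages with the old node features
update_node(l::MyConv, m̄, x) = m̄ .+ l.weight * x

# forward pass: `+` plays the role of the aggregation operator □
(l::MyConv)(g::GNNGraph, x) = propagate(l, g, +, x)[1]
```

Calling `MyConv(3, 4)(g, x)` on a graph `g` with a `3 × num_nodes` feature matrix `x` then runs exactly the ϕ, □, γ_x steps of the equations above.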

src/gnngraph.jl

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ from the LightGraphs' graph library can be used on it.
 
 # Usage.
 
-```
+```julia
 using Flux, GraphNeuralNetworks
 
 # Construct from adjacency list representation
src/layers/conv.jl

Lines changed: 24 additions & 24 deletions
@@ -196,7 +196,7 @@ end
 
 
 @doc raw"""
-    GATConv(in => out;
+    GATConv(in => out, σ=identity;
             heads=1,
             concat=true,
             init=glorot_uniform
@@ -228,6 +228,7 @@ struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
     weight::A
     bias::B
     a::A
+    σ
     negative_slope::T
     channel::Pair{Int, Int}
     heads::Int
@@ -237,44 +238,43 @@ end
 @functor GATConv
 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
 
-function GATConv(ch::Pair{Int,Int};
+function GATConv(ch::Pair{Int,Int}, σ=identity;
                  heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
     b = Flux.create_bias(W, bias, out*heads)
     a = init(2*out, heads)
-    GATConv(W, b, a, negative_slope, ch, heads, concat)
+    GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end
 
-function (gat::GATConv)(g::GNNGraph, X::AbstractMatrix)
-    check_num_nodes(g, X)
+function compute_message(l::GATConv, Wxi, Wxj)
+    aWW = sum(l.a .* cat(Wxi, Wxj, dims=1), dims=1)   # 1 × nheads × nedges
+    α = exp.(leakyrelu.(aWW, l.negative_slope))
+    return (α = α, m = α .* Wxj)
+end
+
+update_node(l::GATConv, d̄, x) = d̄.m ./ d̄.α
+
+function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
     g = add_self_loops(g)
-    chin, chout = gat.channel
-    heads = gat.heads
+    chin, chout = l.channel
+    heads = l.heads
 
-    source, target = edge_index(g)
-    Wx = gat.weight*X
+    Wx = l.weight * x
     Wx = reshape(Wx, chout, heads, :)                 # chout × nheads × nnodes
-    Wxi = NNlib.gather(Wx, target)                    # chout × nheads × nedges
-    Wxj = NNlib.gather(Wx, source)
-
-    # Edge Message
-    # Computing softmax. TODO make it numerically stable
-    aWW = sum(gat.a .* cat(Wxi, Wxj, dims=1), dims=1) # 1 × nheads × nedges
-    α = exp.(leakyrelu.(aWW, gat.negative_slope))
-    m̄ = NNlib.scatter(+, α .* Wxj, target)            # chout × nheads × nnodes
-    ᾱ = NNlib.scatter(+, α, target)                   # 1 × nheads × nnodes
 
-    # Node update
-    b = reshape(gat.bias, chout, heads)
-    X = m̄ ./ ᾱ .+ b                                   # chout × nheads × nnodes
-    if !gat.concat
-        X = sum(X, dims=2)
+    x, _ = propagate(l, g, +, Wx)                     ## chout × nheads × nnodes
+
+    b = reshape(l.bias, chout, heads)
+    x = l.σ.(x .+ b)
+    if !l.concat
+        x = sum(x, dims=2)
     end
 
     # We finally return a matrix
-    return reshape(X, :, size(X, 3))
+    return reshape(x, :, size(x, 3))
 end
 
 
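A usage sketch of the refactored `GATConv` under the new constructor signature; the toy graph and feature sizes are invented for illustration:

```julia
using Flux, GraphNeuralNetworks

g = GNNGraph([[2, 3], [1, 4], [1, 4], [2, 3]])   # toy 4-node graph
x = rand(Float32, 3, 4)                          # 3 features per node

# the activation σ is now a positional argument; attention coefficients are
# produced by compute_message and normalized in update_node via propagate
l = GATConv(3 => 8, relu; heads = 2, concat = true)

y = l(g, x)
size(y)   # (16, 4): out * heads features per node when concat = true
```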

src/msgpass.jl

Lines changed: 24 additions & 12 deletions
@@ -13,7 +13,7 @@ The computational steps are the following:
 
 ```julia
 m = compute_batch_message(l, g, x, e)    # calls `compute_message`
-m̄ = aggregate_neighbors(l, aggr, g, m)
+m̄ = aggregate_neighbors(g, aggr, m)
 x′ = update_node(l, m̄, x)
 e′ = update_edge(l, m, e)
 ```
@@ -24,7 +24,7 @@ this method in the forward pass:
 
 # Usage example
 
-```
+```julia
 using GraphNeuralNetworks, Flux
 
 struct GNNConv <: GNNLayer
@@ -63,7 +63,7 @@ end
 
 function propagate(l, g::GNNGraph, aggr, x, e=nothing)
     m = compute_batch_message(l, g, x, e)
-    m̄ = aggregate_neighbors(l, g, aggr, m)
+    m̄ = aggregate_neighbors(g, aggr, m)
     x′ = update_node(l, m̄, x)
     e′ = update_edge(l, m, e)
     return x′, e′
@@ -74,15 +74,17 @@ end
 """
     compute_message(l, x_i, x_j, [e_ij])
 
-Message function for the message-passing scheme,
-returning the message from node `j` to node `i` .
+Message function for the message-passing scheme
+started by [`propagate`](@ref).
+Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
 in order to update (see [`update_node`](@ref)) the features of node `i`.
 
 The function operates on batches of edges, therefore
 `x_i`, `x_j`, and `e_ij` are tensors whose last dimension
-is the batch size.
+is the batch size, or can be tuples/named tuples of
+such tensors, according to the input to `propagate`.
 
 By default, the function returns `x_j`.
 Custom layers should specialize this method with the desired behavior.
@@ -101,10 +103,12 @@ function compute_message end
 @inline compute_message(l, x_i, x_j, e_ij) = compute_message(l, x_i, x_j)
 @inline compute_message(l, x_i, x_j) = x_j
 
-_gather(x, i) = NNlib.gather(x, i)
+_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
+_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
+_gather(x::AbstractArray, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
 
-function compute_batch_message(l, g, x, e)
+function compute_batch_message(l, g::GNNGraph, x, e)
     s, t = edge_index(g)
     xi = _gather(x, t)
     xj = _gather(x, s)
@@ -114,12 +118,17 @@
 
 ## Step 2
 
-function aggregate_neighbors(l, g, aggr, e)
+_scatter(aggr, e::NamedTuple, t) = map(e -> _scatter(aggr, e, t), e)
+_scatter(aggr, e::Tuple, t) = map(e -> _scatter(aggr, e, t), e)
+_scatter(aggr, e::AbstractArray, t) = NNlib.scatter(aggr, e, t)
+_scatter(aggr, e::Nothing, t) = nothing
+
+function aggregate_neighbors(g::GNNGraph, aggr, e)
     s, t = edge_index(g)
-    NNlib.scatter(aggr, e, t)
+    _scatter(aggr, e, t)
 end
 
-aggregate_neighbors(l, g, aggr::Nothing, e) = nothing
+aggregate_neighbors(g::GNNGraph, aggr::Nothing, e) = nothing
 
 ## Step 3
 
@@ -130,6 +139,9 @@ Node update function for the GNN layer `l`,
 returning a new set of node features `x′` based on old
 features `x` and the aggregated message `m̄` from the neighborhood.
 
+The input `m̄` is an array, a tuple, or a named tuple,
+reflecting the output of [`compute_message`](@ref).
+
 By default, the function returns `m̄`.
 Custom layers should specialize this method with the desired behavior.
 
@@ -148,7 +160,7 @@ function update_node end
 Edge update function for the GNN layer `l`,
 returning a new set of edge features `e′` based on old
 features `e` and the newly computed messages `m`
-from the [`message`](@ref) function.
+from the [`compute_message`](@ref) function.
 
 By default, the function returns `e`.
 Custom layers should specialize this method with the desired behavior.
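The Tuple/NamedTuple support rests on recursing `_gather`/`_scatter` over the container fields with `map`. A standalone sketch of that pattern, where `gather_any` is a hypothetical stand-in for the internal `_gather` and only `NNlib.gather` is a real library call:

```julia
using NNlib

# same recursion as `_gather` above: arrays are gathered directly,
# named tuples are gathered field by field with `map`
gather_any(x::AbstractArray, i) = NNlib.gather(x, i)
gather_any(x::NamedTuple, i)    = map(v -> gather_any(v, i), x)

x = (h = rand(Float32, 4, 3), deg = rand(Float32, 1, 3))   # features of 3 nodes
s = [1, 2, 2, 3]                                           # source node of each edge

xj = gather_any(x, s)   # (h = 4×4 matrix, deg = 1×4 matrix), one column per edge
size(xj.h)              # (4, 4)
```

`_scatter` applies the same trick on the aggregation side, which is what lets `propagate` carry named tuples such as the `(α = ..., m = ...)` message used by the refactored `GATConv`.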

test/cuda/msgpass.jl

Lines changed: 7 additions & 7 deletions
@@ -10,19 +10,19 @@
                0 1 0 1 0 1
                0 1 1 0 1 0]
 
-    struct NewCudaLayer
+    struct NewCudaLayer{G} <: GNNLayer
         weight
     end
-    NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
-    @functor NewCudaLayer
+    NewCudaLayer{GRAPH_T}(m, n) = NewCudaLayer{GRAPH_T}(randn(T, m, n))
+    Flux.@functor NewCudaLayer{GRAPH_T}
 
-    (l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
-    GraphNeuralNetworks.compute_message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-    GraphNeuralNetworks.update_node(::NewCudaLayer, m, x) = m
+    (l::NewCudaLayer{GRAPH_T})(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)[1]
+    GraphNeuralNetworks.compute_message(n::NewCudaLayer{GRAPH_T}, x_i, x_j, e_ij) = n.weight * x_j
+    GraphNeuralNetworks.update_node(::NewCudaLayer{GRAPH_T}, m, x) = m
 
     X = rand(T, in_channel, N) |> gpu
     g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-    l = NewCudaLayer(out_channel, in_channel) |> gpu
+    l = NewCudaLayer{GRAPH_T}(out_channel, in_channel) |> gpu
 
     g_ = l(g)
     @test size(node_features(g_)) == (out_channel, N)

test/msgpass.jl

Lines changed: 35 additions & 1 deletion
@@ -1,3 +1,5 @@
+import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
+
 @testset "message passing" begin
     in_channel = 10
     out_channel = 5
@@ -113,9 +115,9 @@
         GraphNeuralNetworks.compute_message(l::NewLayerW{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j
         GraphNeuralNetworks.update_node(l::NewLayerW{GRAPH_T}, m, x) = l.weight * x + m
 
-        l = NewLayerW(in_channel, out_channel)
         (l::NewLayerW{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)
 
+        l = NewLayerW(in_channel, out_channel)
         g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
         g_ = l(g)
 
@@ -124,4 +126,36 @@
         @test edge_features(g_) === E
         @test graph_features(g_) === U
     end
+
+    @testset "NamedTuples" begin
+        struct NewLayerNT{G}
+            W
+        end
+
+        NewLayerNT(in, out) = NewLayerNT{GRAPH_T}(randn(T, out, in))
+
+        function GraphNeuralNetworks.compute_message(l::NewLayerNT{GRAPH_T}, di, dj, dij)
+            a = l.W * (di.x .+ dj.x .+ dij.e)
+            b = l.W * di.x
+            return (; a, b)
+        end
+        function GraphNeuralNetworks.update_node(l::NewLayerNT{GRAPH_T}, m, d)
+            return (α = l.W * d.x + m.a + m.b, β = m)
+        end
+        function GraphNeuralNetworks.update_edge(l::NewLayerNT{GRAPH_T}, m, e)
+            return m.a
+        end
+
+        function (l::NewLayerNT{GRAPH_T})(g, x, e)
+            x, e = propagate(l, g, mean, (; x), (; e))
+            return x.α .+ x.β.a, e
+        end
+
+        l = NewLayerNT(in_channel, out_channel)
+        g = GNNGraph(adj, graph_type=GRAPH_T)
+        X′, E′ = l(g, X, E)
+
+        @test size(X′) == (out_channel, num_V)
+        @test size(E′) == (out_channel, num_E)
+    end
 end
