Commit f43086e

fix tests
1 parent d5e55eb commit f43086e

File tree

4 files changed: +47 -26 lines changed

docs/src/messagepassing.md

Lines changed: 18 additions & 1 deletion
@@ -1,4 +1,21 @@
 # Message Passing
 
-TODO
+The message passing is initiated by [`propagate`](@ref)
+and can be customized for a specific layer by overloading the methods
+[`compute_message`](@ref), [`update_node`](@ref), and [`update_edge`](@ref).
 
+
+The message passing corresponds to the following operations:
+
+```math
+\begin{aligned}
+\mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
+\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \square_{j\in N(i)} \mathbf{m}_{j\to i}) \\
+\mathbf{e}_{j\to i}' &= \gamma_e(\mathbf{e}_{j\to i}, \mathbf{m}_{j\to i})
+\end{aligned}
+```
+where ``\phi`` is expressed by the [`compute_message`](@ref) function,
+and ``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref),
+respectively.
+
+See the implementations of [`GraphConv`](@ref) and [`GATConv`](@ref) as usage examples.
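
As a quick illustration of the workflow the new docs describe, here is a minimal, self-contained sketch of a custom layer. The name `MyConv` and all dimensions are made up for illustration; it assumes the `propagate(l, g, aggr, x)` signature from `src/msgpass.jl` below:

```julia
using Flux, GraphNeuralNetworks

# Hypothetical layer: transforms neighbor features and sums them.
struct MyConv <: GNNLayer
    weight::Matrix{Float32}
end
MyConv(in::Int, out::Int) = MyConv(randn(Float32, out, in))
Flux.@functor MyConv

# ϕ: the message from node j to node i (here it ignores x_i and e_ij).
GraphNeuralNetworks.compute_message(l::MyConv, x_i, x_j, e_ij) = l.weight * x_j

# γ_x: combine the aggregated message m̄ with a transformed skip connection.
GraphNeuralNetworks.update_node(l::MyConv, m̄, x) = m̄ .+ l.weight * x

# Calling the layer kicks off message passing with sum (+) aggregation.
(l::MyConv)(g::GNNGraph, x) = GraphNeuralNetworks.propagate(l, g, +, x)[1]

g = GNNGraph([0 1 0; 1 0 1; 0 1 0])   # 3-node path graph from adjacency matrix
x = rand(Float32, 4, 3)               # 4 features per node
y = MyConv(4, 8)(g, x)                # 8×3 updated node features
```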

src/msgpass.jl

Lines changed: 14 additions & 9 deletions
@@ -13,7 +13,7 @@ The computational steps are the following:
 
 ```julia
 m = compute_batch_message(l, g, x, e)    # calls `compute_message`
-m̄ = aggregate_neighbors(l, aggr, g, m)
+m̄ = aggregate_neighbors(g, aggr, m)
 x′ = update_node(l, m̄, x)
 e′ = update_edge(l, m, e)
 ```
@@ -63,7 +63,7 @@ end
 
 function propagate(l, g::GNNGraph, aggr, x, e=nothing)
     m = compute_batch_message(l, g, x, e)
-    m̄ = aggregate_neighbors(l, g, aggr, m)
+    m̄ = aggregate_neighbors(g, aggr, m)
     x′ = update_node(l, m̄, x)
     e′ = update_edge(l, m, e)
     return x′, e′
@@ -74,15 +74,17 @@ end
 """
     compute_message(l, x_i, x_j, [e_ij])
 
-Message function for the message-passing scheme,
-returning the message from node `j` to node `i`.
+Message function for the message-passing scheme
+started by [`propagate`](@ref).
+Returns the message from node `j` to node `i`.
 In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
 in order to update (see [`update_node`](@ref)) the features of node `i`.
 
 The function operates on batches of edges, therefore
 `x_i`, `x_j`, and `e_ij` are tensors whose last dimension
-is the batch size.
+is the batch size, or can be tuples/namedtuples of
+such tensors, according to the input to `propagate`.
 
 By default, the function returns `x_j`.
 Custom layers should specialize this method with the desired behavior.
@@ -106,7 +108,7 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
 _gather(x::AbstractArray, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
 
-function compute_batch_message(l, g, x, e)
+function compute_batch_message(l, g::GNNGraph, x, e)
     s, t = edge_index(g)
     xi = _gather(x, t)
     xj = _gather(x, s)
@@ -121,12 +123,12 @@ _scatter(aggr, e::Tuple, t) = map(e -> _scatter(aggr, e, t), e)
 _scatter(aggr, e::AbstractArray, t) = NNlib.scatter(aggr, e, t)
 _scatter(aggr, e::Nothing, t) = nothing
 
-function aggregate_neighbors(l, g, aggr, e)
+function aggregate_neighbors(g::GNNGraph, aggr, e)
     s, t = edge_index(g)
     _scatter(aggr, e, t)
 end
 
-aggregate_neighbors(l, g, aggr::Nothing, e) = nothing
+aggregate_neighbors(g::GNNGraph, aggr::Nothing, e) = nothing
 
 ## Step 3
 
@@ -137,6 +139,9 @@ Node update function for the GNN layer `l`,
 returning a new set of node features `x′` based on old
 features `x` and the aggregated message `m̄` from the neighborhood.
 
+The input `m̄` is an array, a tuple, or a named tuple,
+reflecting the output of [`compute_message`](@ref).
+
 By default, the function returns `m̄`.
 Custom layers should specialize this method with the desired behavior.
 
@@ -155,7 +160,7 @@ function update_node end
 Edge update function for the GNN layer `l`,
 returning a new set of edge features `e′` based on old
 features `e` and the newly computed messages `m`
-from the [`message`](@ref) function.
+from the [`compute_message`](@ref) function.
 
 By default, the function returns `e`.
 Custom layers should specialize this method with the desired behavior.
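
For reviewers, the `_gather`/`_scatter` pair touched above reduces to the following NNlib pattern; a toy sketch with made-up shapes and edge indices:

```julia
using NNlib

# Toy graph: 3 nodes, 4 directed edges s[k] -> t[k].
s = [1, 2, 2, 3]             # source node of each edge
t = [2, 1, 3, 1]             # target node of each edge
x = rand(Float32, 5, 3)      # node features, one column per node

# compute_batch_message gathers per-edge views of the node features:
x_j = NNlib.gather(x, s)     # 5×4: features of each edge's source node
m   = x_j                    # the default message is just x_j

# aggregate_neighbors scatters messages onto their target nodes:
m̄ = NNlib.scatter(+, m, t)   # 5×3: column i sums messages of edges with t[k] == i
```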

test/cuda/msgpass.jl

Lines changed: 7 additions & 7 deletions
@@ -10,19 +10,19 @@
        0 1 0 1 0 1
        0 1 1 0 1 0]
 
-struct NewCudaLayer
+struct NewCudaLayer{G} <: GNNLayer
     weight
 end
-NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
-@functor NewCudaLayer
+NewCudaLayer{GRAPH_T}(m, n) = NewCudaLayer{GRAPH_T}(randn(T, m, n))
+Flux.@functor NewCudaLayer{GRAPH_T}
 
-(l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
-GraphNeuralNetworks.compute_message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-GraphNeuralNetworks.update_node(::NewCudaLayer, m, x) = m
+(l::NewCudaLayer{GRAPH_T})(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)[1]
+GraphNeuralNetworks.compute_message(n::NewCudaLayer{GRAPH_T}, x_i, x_j, e_ij) = n.weight * x_j
+GraphNeuralNetworks.update_node(::NewCudaLayer{GRAPH_T}, m, x) = m
 
 X = rand(T, in_channel, N) |> gpu
 g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-l = NewCudaLayer(out_channel, in_channel) |> gpu
+l = NewCudaLayer{GRAPH_T}(out_channel, in_channel) |> gpu
 
 g_ = l(g)
 @test size(node_features(g_)) == (out_channel, N)

test/msgpass.jl

Lines changed: 8 additions & 9 deletions
@@ -132,26 +132,25 @@ import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
     W
 end
 
-NewLayerNT(in, out) = NewLayerW{GRAPH_T}(randn(T, out, in))
+NewLayerNT(in, out) = NewLayerNT{GRAPH_T}(randn(T, out, in))
 
-function compute_message(l::NewLayerW{GRAPH_T}, di, dj, dij)
-    a = l.W * (di.x .+ dj.x) + dij.e
+function GraphNeuralNetworks.compute_message(l::NewLayerNT{GRAPH_T}, di, dj, dij)
+    a = l.W * (di.x .+ dj.x .+ dij.e)
     b = l.W * di.x
     return (; a, b)
 end
-function update_node(l::NewLayerW{GRAPH_T}, m, x)
-    return (α = l.W * x + m.a + m.b, β = m)
+function GraphNeuralNetworks.update_node(l::NewLayerNT{GRAPH_T}, m, d)
+    return (α = l.W * d.x + m.a + m.b, β = m)
 end
-function update_edge(l::NewLayerW{GRAPH_T}, m, e)
+function GraphNeuralNetworks.update_edge(l::NewLayerNT{GRAPH_T}, m, e)
     return m.a
 end
 
-function (::NewLayerNT)(l, g, x, e)
-    x, e = propagate(l, g, (; x), (; e))
+function (l::NewLayerNT{GRAPH_T})(g, x, e)
+    x, e = propagate(l, g, mean, (; x), (; e))
     return x.α .+ x.β.a, e
 end
 
-
 l = NewLayerNT(in_channel, out_channel)
 g = GNNGraph(adj, graph_type=GRAPH_T)
 X′, E′ = l(g, X, E)
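
The named-tuple plumbing this test exercises rests on field-wise aggregation (`_scatter(aggr, e::Tuple, t) = map(...)` above). A sketch of that behavior in isolation, with made-up shapes, assuming named tuples are mapped over like tuples:

```julia
using NNlib

t = [1, 1, 2]                          # target node of each of 3 edges
m = (a = rand(Float32, 4, 3),          # two message fields,
     b = rand(Float32, 4, 3))          # one column per edge

# Each field is aggregated independently, so update_node receives a named
# tuple m̄ with the same keys returned by compute_message.
m̄ = map(f -> NNlib.scatter(+, f, t), m)
size(m̄.a), size(m̄.b)                   # ((4, 2), (4, 2))
```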
