
Commit 95d0c50

Commit message: work
1 parent c18a71f commit 95d0c50

7 files changed (+152, -102 lines)

docs/src/api/messagepassing.md

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@
 
 ```@docs
 GraphNeuralNetworks.message
-GraphNeuralNetworks.update
+GraphNeuralNetworks.update_node
 GraphNeuralNetworks.propagate
 ```

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ export
     sprand, sparse,
 
     # msgpass
-    # update, update_edge, message, propagate,
+    # update_node, update_edge, message, propagate,
 
     # layers/basic
     GNNLayer,

src/layers/conv.jl

Lines changed: 12 additions & 10 deletions
@@ -45,8 +45,8 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
     l.σ.(l.weight * x * Ã .+ l.bias)
 end
 
-message(l::GCNConv, xi, xj) = xj
-update(l::GCNConv, x, m) = m
+message(l::GCNConv, xi, xj, eij) = xj
+update_node(l::GCNConv, m, x) = m
 
 function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
     g = add_self_loops(g)
@@ -177,11 +177,11 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
 end
 
 message(l::GraphConv, x_i, x_j, e_ij) = x_j
-update(l::GraphConv, x, m) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
+update_node(l::GraphConv, m, x) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
 
 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    x, _ = propagate(l, g, +, x)
+    x, _ = propagate(l, g, l.aggr, x)
     x
 end
 
@@ -235,6 +235,7 @@ struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
 end
 
 @functor GATConv
+Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
 
 function GATConv(ch::Pair{Int,Int};
                  heads::Int=1, concat::Bool=true, negative_slope=0.2f0,
@@ -316,6 +317,7 @@
 
 @functor GatedGraphConv
 
+
 function GatedGraphConv(out_ch::Int, num_layers::Int;
                         aggr=+, init=glorot_uniform)
     w = init(out_ch, out_ch, num_layers)
@@ -325,7 +327,7 @@
 
 
 message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
-update(l::GatedGraphConv, x, m) = m
+update_node(l::GatedGraphConv, m, x) = m
 
 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)
@@ -340,7 +342,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:Abstra
     end
     for i = 1:l.num_layers
         M = view(l.weight, :, :, i) * H
-        M, _ = propagate(l, g, +, M)
+        M, _ = propagate(l, g, l.aggr, M)
         H, _ = l.gru(H, M)
     end
     H
@@ -381,11 +383,11 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
 
 message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
 
-update(l::EdgeConv, x, m) = m
+update_node(l::EdgeConv, m, x) = m
 
 function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    X, _ = propagate(l, g, +, X)
+    X, _ = propagate(l, g, l.aggr, X)
     X
 end
 
@@ -424,8 +426,8 @@ function GINConv(nn; eps=0f0)
     GINConv(nn, eps)
 end
 
-message(l::GINConv, x_i, x_j) = x_j
-update(l::GINConv, x, m) = l.nn((1 + l.eps) * x + m)
+message(l::GINConv, x_i, x_j, e_ij) = x_j
+update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
 
 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
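
The `propagate(l, g, l.aggr, ...)` changes above mean the aggregation chosen at construction time is now actually used in the forward pass, where `+` was previously hard-coded. Below is a minimal sketch of the effect, not part of the commit; the tiny graph and the feature sizes are made up, and the positional `aggr` argument follows the `GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; ...)` signature visible in the hunk header:

```julia
using GraphNeuralNetworks
using Statistics: mean

# toy fully connected graph with 3 nodes and 3 features per node
adj = [0 1 1;
       1 0 1;
       1 1 0]
x = rand(Float32, 3, 3)
g = GNNGraph(adj, ndata=x)

l = GraphConv(3 => 4, identity, mean)  # aggr is the third positional argument
y = l(g, x)                            # forward now calls propagate(l, g, l.aggr, x)
@assert size(y) == (4, 3)              # 4 output features for each of the 3 nodes
```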

src/msgpass.jl

Lines changed: 99 additions & 65 deletions
@@ -1,60 +1,88 @@
-# Adapted message passing from paper
-# "Relational inductive biases, deep learning, and graph networks"
-
 """
-    propagate(l, g, aggr, [X, E]) -> X′, E
+    propagate(l, g, aggr, [x, e]) -> x′, e′
     propagate(l, g, aggr) -> g′
 
-Perform the sequence of operations implementing the message-passing scheme
-of gnn layer `l` on graph `g` .
-Updates the node, edge, and global features `X`, `E`, and `U` respectively.
+Performs the message passing for GNN layer `l` on graph `g`.
+Returns updated node and edge features `x` and `e`.
+
+If no node and edge features are given as input,
+they are extracted from `g` and the same graph
+is returned with updated features.
 
-The computation involved is the following:
+The computational steps are the following:
 
 ```julia
-M = compute_batch_message(l, g, X, E)
-M̄ = aggregate_neighbors(l, aggr, g, M)
-X′ = update(l, X, M̄)
-E′ = update_edge(l, E, M)
+m = compute_batch_message(l, g, x, e)  # calls `message`
+m̄ = aggregate_neighbors(l, g, aggr, m)
+x′ = update_node(l, m̄, x)
+e′ = update_edge(l, m, e)
 ```
 
-Custom layers typically define their own [`update`](@ref)
+Custom layers typically define their own [`update_node`](@ref)
 and [`message`](@ref) functions, then call
 this method in the forward pass:
 
-```julia
-function (l::MyLayer)(g, X)
-    ... some prepocessing if needed ...
-    propagate(l, g, +, X, E)
+# Usage example
+
+```julia
+using GraphNeuralNetworks, Flux
+
+struct GNNConv <: GNNLayer
+    W
+    b
+    σ
+end
+
+Flux.@functor GNNConv
+
+function GNNConv(ch::Pair{Int,Int}, σ=identity;
+                 init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    W = init(out, in)
+    b = Flux.create_bias(W, bias, out)
+    GNNConv(W, b, σ)
+end
+
+message(l::GNNConv, x_i, x_j, e_ij) = l.W * x_j
+update_node(l::GNNConv, m̄, x) = l.σ.(m̄ .+ l.b)
+
+function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
+    x, _ = propagate(l, g, +, x)
+    return x
 end
 ```
 
-See also [`message`](@ref) and [`update`](@ref).
+See also [`message`](@ref) and [`update_node`](@ref).
 """
 function propagate end
 
 function propagate(l, g::GNNGraph, aggr)
-    X, E = propagate(l, g, aggr, node_features(g), edge_features(g))
-
-    return GNNGraph(g, ndata=X, edata=E)
+    x, e = propagate(l, g, aggr, node_features(g), edge_features(g))
+    return GNNGraph(g, ndata=x, edata=e)
 end
 
-function propagate(l, g::GNNGraph, aggr, X, E=nothing)
-    M = compute_batch_message(l, g, X, E)
-    M̄ = aggregate_neighbors(l, g, aggr, M)
-    X′ = update(l, X, M̄)
-    E′ = update_edge(l, E, M)
-    return X′, E′, U
+function propagate(l, g::GNNGraph, aggr, x, e=nothing)
+    m = compute_batch_message(l, g, x, e)
+    m̄ = aggregate_neighbors(l, g, aggr, m)
+    x′ = update_node(l, m̄, x)
+    e′ = update_edge(l, m, e)
+    return x′, e′
 end
 
+## Step 1.
+
 """
     message(l, x_i, x_j, [e_ij])
 
 Message function for the message-passing scheme,
 returning the message from node `j` to node `i` .
 In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
-in order to [`update`](@ref) the features of node `i`.
+in order to update (see [`update_node`](@ref)) the features of node `i`.
+
+The function operates on batches of edges, therefore
+`x_i`, `x_j`, and `e_ij` are tensors whose last dimension
+is the batch size.
 
 By default, the function returns `x_j`.
 Custom layer should specialize this method with the desired behavior.
@@ -66,63 +94,69 @@ Custom layer should specialize this method with the desired behavior.
 - `x_j`: Features of the neighbor `j` of node `i`.
 - `e_ij`: Features of edge `(i,j)`.
 
-See also [`update`](@ref) and [`propagate`](@ref).
+See also [`update_node`](@ref) and [`propagate`](@ref).
 """
 function message end
 
-"""
-    update(l, x, m̄)
-
-Update function for the message-passing scheme,
-returning a new set of node features `x′` based on old
-features `x` and the incoming message from the neighborhood
-aggregation `m̄`.
-
-By default, the function returns `m̄`.
-Custom layers should specialize this method with the desired behavior.
-
-# Arguments
-
-- `l`: A gnn layer.
-- `m̄`: Aggregated edge messages from the [`message`](@ref) function.
-- `x`: Node features to be updated.
-- `u`: Global features.
-
-See also [`message`](@ref) and [`propagate`](@ref).
-"""
-function update end
+@inline message(l, x_i, x_j, e_ij) = message(l, x_i, x_j)
+@inline message(l, x_i, x_j) = x_j
 
 _gather(x, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
 
-## Step 1.
-
-function compute_batch_message(l, g, X, E)
+function compute_batch_message(l, g, x, e)
     s, t = edge_index(g)
-    Xi = _gather(X, t)
-    Xj = _gather(X, s)
-    M = message(l, Xi, Xj, E)
-    return M
+    xi = _gather(x, t)
+    xj = _gather(x, s)
+    m = message(l, xi, xj, e)
+    return m
 end
 
-@inline message(l, x_i, x_j, e_ij) = message(l, x_i, x_j)
-@inline message(l, x_i, x_j) = x_j
-
 ## Step 2
 
-function aggregate_neighbors(l, g, aggr, E)
+function aggregate_neighbors(l, g, aggr, e)
     s, t = edge_index(g)
-    NNlib.scatter(aggr, E, t)
+    NNlib.scatter(aggr, e, t)
 end
 
-aggregate_neighbors(l, g, aggr::Nothing, E) = nothing
+aggregate_neighbors(l, g, aggr::Nothing, e) = nothing
 
 ## Step 3
 
-@inline update(l, x, m̄) = m̄
+"""
+    update_node(l, m̄, x)
+
+Node update function for the GNN layer `l`,
+returning a new set of node features `x′` based on old
+features `x` and the aggregated message `m̄` from the neighborhood.
+
+By default, the function returns `m̄`.
+Custom layers should specialize this method with the desired behavior.
+
+See also [`message`](@ref), [`update_edge`](@ref), and [`propagate`](@ref).
+"""
+function update_node end
+
+@inline update_node(l, m̄, x) = m̄
 
 ## Step 4
 
-@inline update_edge(l, E, M) = E
+
+"""
+    update_edge(l, m, e)
+
+Edge update function for the GNN layer `l`,
+returning a new set of edge features `e′` based on old
+features `e` and the newly computed messages `m`
+from the [`message`](@ref) function.
+
+By default, the function returns `e`.
+Custom layers should specialize this method with the desired behavior.
+
+See also [`message`](@ref), [`update_node`](@ref), and [`propagate`](@ref).
+"""
+function update_edge end
+
+@inline update_edge(l, m, e) = e
 
 ### end steps ###
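
For layers defined outside the package, the three extension points above (`message`, `update_node`, `update_edge`) are extended with qualified method definitions, as the CUDA test below also does. A hypothetical sketch, not part of the commit; the layer name `EWConv`, the 1 × num_edges edge-weight layout, and the toy shapes are assumptions:

```julia
using GraphNeuralNetworks, Flux

struct EWConv <: GNNLayer   # hypothetical edge-weighted layer, for illustration only
    W
end
Flux.@functor EWConv

# Step 1: per-edge messages; e_ij holds the (1 × num_edges) edge weights
GraphNeuralNetworks.message(l::EWConv, x_i, x_j, e_ij) = (l.W * x_j) .* e_ij
# Step 3: node update from the aggregated message m̄
GraphNeuralNetworks.update_node(l::EWConv, m̄, x) = relu.(m̄)
# Step 4: keep the edge features unchanged (this is also the default)
GraphNeuralNetworks.update_edge(l::EWConv, m, e) = e

function (l::EWConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    x′, e′ = GraphNeuralNetworks.propagate(l, g, +, x, e)
    return x′, e′
end

# usage on a toy graph with one scalar weight per (directed) edge
adj = [0 1; 1 0]
x = rand(Float32, 4, 2)
g = GNNGraph(adj, ndata=x)
e = rand(Float32, 1, g.num_edges)
l = EWConv(rand(Float32, 4, 4))
xnew, enew = l(g, x, e)
```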

test/cuda/msgpass.jl

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 (l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
 GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-GraphNeuralNetworks.update(::NewCudaLayer, x, m) = m
+GraphNeuralNetworks.update_node(::NewCudaLayer, m, x) = m
 
 X = rand(T, in_channel, N) |> gpu
 g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)

test/layers/conv.jl

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@
     @test size(gat.weight) == (out_channel * heads, in_channel)
     @test size(gat.bias) == (out_channel * heads,)
     @test size(gat.a) == (2*out_channel, heads)
+    @test length(Flux.trainable(gat)) == 3
 
     g_ = gat(g_gat)
     Y = node_features(g_)
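
The added test asserts that `Flux.trainable` on a `GATConv` now exposes exactly the three parameter arrays listed in `Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)`. A rough equivalent outside the test suite (the channel sizes and head count are arbitrary):

```julia
using GraphNeuralNetworks, Flux

gat = GATConv(3 => 4, heads=2)   # `heads` is a keyword, as in the constructor above
ps = Flux.trainable(gat)         # (gat.weight, gat.bias, gat.a) after this commit
@assert length(ps) == 3
```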
