Skip to content

Commit b5512db

Browse files
change ordering in propagate/update
1 parent 4deb30e commit b5512db

File tree

4 files changed

+115
-116
lines changed

4 files changed

+115
-116
lines changed

src/layers/conv.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
4646
end
4747

4848
message(l::GCNConv, xi, xj) = xj
49-
update(l::GCNConv, m, x) = m
49+
update(l::GCNConv, x, m) = m
5050

5151
function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
5252
g = add_self_loops(g)
5353
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5454
x = x .* c'
55-
_, x = propagate(l, g, nothing, x, nothing, +)
55+
x, _ = propagate(l, g, +, x)
5656
x = x .* c'
5757
return l.σ.(l.weight * x .+ l.bias)
5858
end
@@ -176,12 +176,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
176176
GraphConv(W1, W2, b, σ, aggr)
177177
end
178178

179-
message(gc::GraphConv, x_i, x_j, e_ij) = x_j
180-
update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias)
179+
message(l::GraphConv, x_i, x_j, e_ij) = x_j
180+
update(l::GraphConv, x, m) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
181181

182-
function (gc::GraphConv)(g::GNNGraph, x::AbstractMatrix)
182+
function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
183183
check_num_nodes(g, x)
184-
_, x = propagate(gc, g, nothing, x, nothing, +)
184+
x, _ = propagate(l, g, +, x)
185185
x
186186
end
187187

@@ -325,23 +325,23 @@ end
325325

326326

327327
message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
328-
update(l::GatedGraphConv, m, x) = m
328+
update(l::GatedGraphConv, x, m) = m
329329

330330
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
331331
@non_differentiable fill!(x...)
332332

333-
function (ggc::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
333+
function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
334334
check_num_nodes(g, H)
335335
m, n = size(H)
336-
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
337-
if m < ggc.out_ch
338-
Hpad = similar(H, S, ggc.out_ch - m, n)
336+
@assert (m <= l.out_ch) "number of input features must less or equals to output features."
337+
if m < l.out_ch
338+
Hpad = similar(H, S, l.out_ch - m, n)
339339
H = vcat(H, fill!(Hpad, 0))
340340
end
341-
for i = 1:ggc.num_layers
342-
M = view(ggc.weight, :, :, i) * H
343-
_, M = propagate(ggc, g, nothing, M, nothing, +)
344-
H, _ = ggc.gru(H, M)
341+
for i = 1:l.num_layers
342+
M = view(l.weight, :, :, i) * H
343+
M, _ = propagate(l, g, +, M)
344+
H, _ = l.gru(H, M)
345345
end
346346
H
347347
end
@@ -379,13 +379,13 @@ end
379379

380380
EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
381381

382-
message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
382+
message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
383383

384-
update(ec::EdgeConv, m, x) = m
384+
update(l::EdgeConv, x, m) = m
385385

386-
function (ec::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
386+
function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
387387
check_num_nodes(g, X)
388-
_, X = propagate(ec, g, nothing, X, nothing, ec.aggr)
388+
X, _ = propagate(l, g, +, X)
389389
X
390390
end
391391

@@ -425,10 +425,10 @@ function GINConv(nn; eps=0f0)
425425
end
426426

427427
message(l::GINConv, x_i, x_j) = x_j
428-
update(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
428+
update(l::GINConv, x, m) = l.nn((1 + l.eps) * x + m)
429429

430430
function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
431431
check_num_nodes(g, X)
432-
_, X = propagate(l, g, nothing, X, nothing, +)
432+
X, _ = propagate(l, g, +, X)
433433
X
434434
end

src/msgpass.jl

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# "Relational inductive biases, deep learning, and graph networks"
33

44
"""
5-
propagate(mp, g, X, E, U, aggr)
5+
propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
6+
propagate(mp, g, aggr) -> g′
67
78
Perform the sequence of operations implementing the message-passing scheme
8-
on graph `g` with convolution layer `mp`.
9+
of gnn layer `mp` on graph `g` .
910
Updates the node, edge, and global features `X`, `E`, and `U` respectively.
1011
1112
The computation involved is the following:
@@ -14,7 +15,7 @@ The computation involved is the following:
1415
M = compute_batch_message(mp, g, X, E, U)
1516
M̄ = aggregate_neighbors(mp, aggr, g, M)
1617
X′ = update(mp, X, M̄, U)
17-
E′ = update_edge(mp, M, E, U)
18+
E′ = update_edge(mp, E, M, U)
1819
U′ = update_global(mp, U, X′, E′)
1920
```
2021
@@ -25,7 +26,7 @@ this method in the forward pass:
2526
```julia
2627
function (l::MyLayer)(g, X)
2728
... some prepocessing if needed ...
28-
propagate(l, g, X, E, U, +)
29+
propagate(l, g, +, X, E, U)
2930
end
3031
```
3132
@@ -34,28 +35,28 @@ See also [`message`](@ref) and [`update`](@ref).
3435
function propagate end
3536

3637
function propagate(mp, g::GNNGraph, aggr)
37-
X, E, U = propagate(mp, g,
38-
node_features(g), edge_features(g), global_features(g),
39-
aggr)
40-
GNNGraph(g, ndata=X, edata=E, gdata=U)
38+
X, E, U = propagate(mp, g, aggr,
39+
node_features(g), edge_features(g), global_features(g))
40+
41+
return GNNGraph(g, ndata=X, edata=E, gdata=U)
4142
end
4243

43-
function propagate(mp, g::GNNGraph, E, X, U, aggr)
44+
function propagate(mp, g::GNNGraph, aggr, X, E=nothing, U=nothing)
4445
# TODO consider g.graph_indicator in propagating U
45-
M = compute_batch_message(mp, g, E, X, U)
46-
E = update_edge(mp, M, E, U)
47-
M̄ = aggregate_neighbors(mp, aggr, g, M)
48-
X = update(mp, M̄, X, U)
49-
U = update_global(mp, E, X, U)
50-
return E, X, U
46+
M = compute_batch_message(mp, g, X, E, U)
47+
M̄ = aggregate_neighbors(mp, g, aggr, M)
48+
X′ = update(mp, X, M̄, U)
49+
E′ = update_edge(mp, E, M, U)
50+
U = update_global(mp, U, X′, E′)
51+
return X′, E′, U
5152
end
5253

5354
"""
5455
message(mp, x_i, x_j, [e_ij, u])
5556
5657
Message function for the message-passing scheme,
5758
returning the message from node `j` to node `i` .
58-
In the message-passing scheme. the incoming messages
59+
In the message-passing scheme, the incoming messages
5960
from the neighborhood of `i` will later be aggregated
6061
in order to [`update`](@ref) the features of node `i`.
6162
@@ -75,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
7576
function message end
7677

7778
"""
78-
update(mp, m̄, x, [u])
79+
update(mp, x, m̄, [u])
7980
8081
Update function for the message-passing scheme,
8182
returning a new set of node features `x′` based on old
@@ -96,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
9697
"""
9798
function update end
9899

99-
100100
_gather(x, i) = NNlib.gather(x, i)
101101
_gather(x::Nothing, i) = nothing
102102

103103
## Step 1.
104104

105-
function compute_batch_message(mp, g, E, X, u)
105+
function compute_batch_message(mp, g, X, E, U)
106106
s, t = edge_index(g)
107107
Xi = _gather(X, t)
108108
Xj = _gather(X, s)
109-
M = message(mp, Xi, Xj, E, u)
109+
M = message(mp, Xi, Xj, E, U)
110110
return M
111111
end
112112

113-
# @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
114113
@inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
115114
@inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
116115
@inline message(mp, x_i, x_j) = x_j
117116

118-
## Step 2
119-
120-
@inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
121-
@inline update_edge(mp, M, E) = E
122-
123-
## Step 3
117+
## Step 2
124118

125-
function aggregate_neighbors(mp, aggr, g, E)
119+
function aggregate_neighbors(mp, g, aggr, E)
126120
s, t = edge_index(g)
127121
NNlib.scatter(aggr, E, t)
128122
end
129123

130-
aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing
124+
aggregate_neighbors(mp, g, aggr::Nothing, E) = nothing
125+
126+
## Step 3
127+
128+
@inline update(mp, x, m̄, u) = update(mp, x, m̄)
129+
@inline update(mp, x, m̄) = m̄
131130

132131
## Step 4
133132

134-
# @inline update(mp, i, m̄, x, u) = update(mp, m̄, x, u)
135-
@inline update(mp, m̄, x, u) = update(mp, m̄, x)
136-
@inline update(mp, m̄, x) = m̄
133+
@inline update_edge(mp, E, M, U) = update_edge(mp, E, M)
134+
@inline update_edge(mp, E, M) = E
137135

138136
## Step 5
139137

140-
@inline update_global(mp, E, X, u) = u
138+
@inline update_global(mp, U, X, E) = update_global(mp, U, X)
139+
@inline update_global(mp, U, X) = U
141140

142141
### end steps ###

test/cuda/msgpass.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
in_channel = 10
2-
out_channel = 5
3-
N = 6
4-
T = Float32
5-
adj = [0 1 0 0 0 0
6-
1 0 0 1 1 1
7-
0 0 0 0 0 1
8-
0 1 0 0 1 0
9-
0 1 0 1 0 1
10-
0 1 1 0 1 0]
1+
@testset "cuda/msgpass" begin
2+
in_channel = 10
3+
out_channel = 5
4+
N = 6
5+
T = Float32
6+
adj = [0 1 0 0 0 0
7+
1 0 0 1 1 1
8+
0 0 0 0 0 1
9+
0 1 0 0 1 0
10+
0 1 0 1 0 1
11+
0 1 1 0 1 0]
1112

12-
struct NewCudaLayer
13-
weight
14-
end
15-
NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))
16-
@functor NewCudaLayer
13+
struct NewCudaLayer
14+
weight
15+
end
16+
NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
17+
@functor NewCudaLayer
1718

18-
(l::NewCudaLayer)(X) = GraphNeuralNetworks.propagate(l, X, +)
19-
GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
20-
GraphNeuralNetworks.update(::NewCudaLayer, m, x) = m
19+
(l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
20+
GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
21+
GraphNeuralNetworks.update(::NewCudaLayer, x, m) = m
2122

22-
X = rand(T, in_channel, N) |> gpu
23-
g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
24-
l = NewCudaLayer(out_channel, in_channel) |> gpu
23+
X = rand(T, in_channel, N) |> gpu
24+
g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
25+
l = NewCudaLayer(out_channel, in_channel) |> gpu
2526

26-
@testset "cuda/msgpass" begin
2727
g_ = l(g)
2828
@test size(node_features(g_)) == (out_channel, N)
2929
end

0 commit comments

Comments
 (0)