
Commit 1ef14a8

Merge pull request #21 from CarloLucibello/cl/data2
change argument ordering in propagate and update
2 parents 1dc219c + b5512db

4 files changed: 121 additions, 124 deletions

src/layers/conv.jl (21 additions, 21 deletions)

@@ -46,13 +46,13 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
 end
 
 message(l::GCNConv, xi, xj) = xj
-update(l::GCNConv, m, x) = m
+update(l::GCNConv, x, m) = m
 
 function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
     g = add_self_loops(g)
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    _, x = propagate(l, g, nothing, x, nothing, +)
+    x, _ = propagate(l, g, +, x)
     x = x .* c'
     return l.σ.(l.weight * x .+ l.bias)
 end
@@ -176,12 +176,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     GraphConv(W1, W2, b, σ, aggr)
 end
 
-message(gc::GraphConv, x_i, x_j, e_ij) = x_j
-update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias)
+message(l::GraphConv, x_i, x_j, e_ij) = x_j
+update(l::GraphConv, x, m) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
 
-function (gc::GraphConv)(g::GNNGraph, x::AbstractMatrix)
+function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    _, x = propagate(gc, g, nothing, x, nothing, +)
+    x, _ = propagate(l, g, +, x)
     x
 end
 
@@ -325,23 +325,23 @@ end
 
 
 message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
-update(l::GatedGraphConv, m, x) = m
+update(l::GatedGraphConv, x, m) = m
 
 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)
 
-function (ggc::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
     check_num_nodes(g, H)
     m, n = size(H)
-    @assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
-    if m < ggc.out_ch
-        Hpad = similar(H, S, ggc.out_ch - m, n)
+    @assert (m <= l.out_ch) "number of input features must be less than or equal to the number of output features."
+    if m < l.out_ch
+        Hpad = similar(H, S, l.out_ch - m, n)
         H = vcat(H, fill!(Hpad, 0))
     end
-    for i = 1:ggc.num_layers
-        M = view(ggc.weight, :, :, i) * H
-        _, M = propagate(ggc, g, nothing, M, nothing, +)
-        H, _ = ggc.gru(H, M)
+    for i = 1:l.num_layers
+        M = view(l.weight, :, :, i) * H
+        M, _ = propagate(l, g, +, M)
+        H, _ = l.gru(H, M)
     end
     H
 end
@@ -379,13 +379,13 @@ end
 
 EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
 
-message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
+message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
 
-update(ec::EdgeConv, m, x) = m
+update(l::EdgeConv, x, m) = m
 
-function (ec::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
+function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    _, X = propagate(ec, g, nothing, X, nothing, ec.aggr)
+    X, _ = propagate(l, g, l.aggr, X)
     X
 end
 
@@ -425,10 +425,10 @@ function GINConv(nn; eps=0f0)
 end
 
 message(l::GINConv, x_i, x_j) = x_j
-update(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
+update(l::GINConv, x, m) = l.nn((1 + l.eps) * x + m)
 
 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    _, X = propagate(l, g, nothing, X, nothing, +)
+    X, _ = propagate(l, g, +, X)
    X
 end
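
Every layer above now destructures the node result first: `x, _ = propagate(l, g, aggr, x)` replaces `_, x = propagate(l, g, nothing, x, nothing, aggr)`, and `update` receives the node features before the aggregated message. A minimal sketch of the new convention follows; the layer `MyConv`, its field, and the activation are hypothetical, not part of this commit:

```julia
using GraphNeuralNetworks
import GraphNeuralNetworks: message, update, propagate

# Hypothetical layer illustrating the reordered API.
struct MyConv
    weight::Matrix{Float32}
end
MyConv(out, in) = MyConv(randn(Float32, out, in))

# message is unchanged in spirit: forward the neighbor's features.
message(l::MyConv, x_i, x_j) = x_j

# update now takes node features `x` first, aggregated message `m` second.
update(l::MyConv, x, m) = tanh.(l.weight * (x .+ m))

function (l::MyConv)(g::GNNGraph, x::AbstractMatrix)
    # The aggregator follows the graph; E and U default to `nothing`.
    x, _ = propagate(l, g, +, x)
    return x
end
```

Moving `aggr` forward is what lets the trailing `E` and `U` arguments default to `nothing`, removing the explicit placeholder arguments from every layer's forward pass.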

src/msgpass.jl (38 additions, 41 deletions)

@@ -2,32 +2,31 @@
 # "Relational inductive biases, deep learning, and graph networks"
 
 """
-    propagate(mp, g::GNNGraph, aggr)
-    propagate(mp, g::GNNGraph, E, X, u, aggr)
+    propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
+    propagate(mp, g, aggr) -> g′
 
-Perform the sequence of operation implementing the message-passing scheme
-and updating node, edge, and global features `X`, `E`, and `u` respectively.
+Perform the sequence of operations implementing the message-passing scheme
+of GNN layer `mp` on graph `g`.
+Updates the node, edge, and global features `X`, `E`, and `U` respectively.
 
 The computation involved is the following:
 
 ```julia
-M = compute_batch_message(mp, g, E, X, u)
-E = update_edge(mp, M, E, u)
+M = compute_batch_message(mp, g, X, E, U)
 M̄ = aggregate_neighbors(mp, aggr, g, M)
-X = update(mp, M̄, X, u)
-u = update_global(mp, E, X, u)
+X′ = update(mp, X, M̄, U)
+E′ = update_edge(mp, E, M, U)
+U′ = update_global(mp, U, X′, E′)
 ```
 
 Custom layers typically define their own [`update`](@ref)
-and [`message`](@ref) function, then call
+and [`message`](@ref) functions, then call
 this method in the forward pass:
 
 ```julia
 function (l::MyLayer)(g, X)
     ... some preprocessing if needed ...
-    E = nothing
-    u = nothing
-    propagate(l, g, E, X, u, +)
+    propagate(l, g, +, X, E, U)
 end
 ```
 
@@ -36,28 +35,28 @@ See also [`message`](@ref) and [`update`](@ref).
 function propagate end
 
 function propagate(mp, g::GNNGraph, aggr)
-    E, X, U = propagate(mp, g,
-                        edge_features(g), node_features(g), global_features(g),
-                        aggr)
-    GNNGraph(g, ndata=X, edata=E, gdata=U)
+    X, E, U = propagate(mp, g, aggr,
+                        node_features(g), edge_features(g), global_features(g))
+
+    return GNNGraph(g, ndata=X, edata=E, gdata=U)
 end
 
-function propagate(mp, g::GNNGraph, E, X, U, aggr)
+function propagate(mp, g::GNNGraph, aggr, X, E=nothing, U=nothing)
     # TODO consider g.graph_indicator in propagating U
-    M = compute_batch_message(mp, g, E, X, U)
-    E = update_edge(mp, M, E, U)
-    M̄ = aggregate_neighbors(mp, aggr, g, M)
-    X = update(mp, M̄, X, U)
-    U = update_global(mp, E, X, U)
-    return E, X, U
+    M = compute_batch_message(mp, g, X, E, U)
+    M̄ = aggregate_neighbors(mp, g, aggr, M)
+    X′ = update(mp, X, M̄, U)
+    E′ = update_edge(mp, E, M, U)
+    U′ = update_global(mp, U, X′, E′)
+    return X′, E′, U′
 end
 
 """
     message(mp, x_i, x_j, [e_ij, u])
 
 Message function for the message-passing scheme,
 returning the message from node `j` to node `i`.
-In the message-passing scheme. the incoming messages
+In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
 in order to [`update`](@ref) the features of node `i`.
 
@@ -77,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
 function message end
 
 """
-    update(mp, m̄, x, [u])
+    update(mp, x, m̄, [u])
 
 Update function for the message-passing scheme,
 returning a new set of node features `x′` based on old
@@ -98,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
 """
 function update end
 
-
 _gather(x, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
 
 ## Step 1.
 
-function compute_batch_message(mp, g, E, X, u)
+function compute_batch_message(mp, g, X, E, U)
     s, t = edge_index(g)
     Xi = _gather(X, t)
     Xj = _gather(X, s)
-    M = message(mp, Xi, Xj, E, u)
+    M = message(mp, Xi, Xj, E, U)
     return M
 end
 
-# @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
 @inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
 @inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
 @inline message(mp, x_i, x_j) = x_j
 
-## Step 2
-
-@inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
-@inline update_edge(mp, M, E) = E
-
-## Step 3
+## Step 2
 
-function aggregate_neighbors(mp, aggr, g, E)
+function aggregate_neighbors(mp, g, aggr, E)
     s, t = edge_index(g)
     NNlib.scatter(aggr, E, t)
 end
 
-aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing
+aggregate_neighbors(mp, g, aggr::Nothing, E) = nothing
+
+## Step 3
+
+@inline update(mp, x, m̄, u) = update(mp, x, m̄)
+@inline update(mp, x, m̄) = m̄
 
 ## Step 4
 
-# @inline update(mp, i, m̄, x, u) = update(mp, m, x, u)
-@inline update(mp, m̄, x, u) = update(mp, m̄, x)
-@inline update(mp, m̄, x) = m̄
+@inline update_edge(mp, E, M, U) = update_edge(mp, E, M)
+@inline update_edge(mp, E, M) = E
 
 ## Step 5
 
-@inline update_global(mp, E, X, u) = u
+@inline update_global(mp, U, X, E) = update_global(mp, U, X)
+@inline update_global(mp, U, X) = U
 
 ### end steps ###
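
The renumbered steps form a plain gather/scatter pipeline. The following self-contained sketch (the toy edge list and feature sizes are made up) mirrors what Step 1 and Step 2 compute, using the same `NNlib.gather`/`NNlib.scatter` calls that appear in `compute_batch_message` and `aggregate_neighbors`:

```julia
using NNlib

s = [1, 2, 2, 3]            # source node of each edge
t = [2, 1, 3, 1]            # target node of each edge
X = rand(Float32, 4, 3)     # 4 features per node, 3 nodes

# Step 1: gather endpoint features per edge and form messages.
Xi = NNlib.gather(X, t)     # features of the target node i of each edge
Xj = NNlib.gather(X, s)     # features of the source node j of each edge
M  = Xj                     # default message(mp, x_i, x_j) = x_j

# Step 2: aggregate incoming messages into each target node.
M̄ = NNlib.scatter(+, M, t)  # size 4×3; column i sums messages sent to node i
```

`scatter(+, M, t)` adds column `k` of `M` into column `t[k]` of the output, which is exactly the neighborhood aggregation `M̄` consumed by `update` in Step 3.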

test/cuda/msgpass.jl (22 additions, 22 deletions)

@@ -1,29 +1,29 @@
-in_channel = 10
-out_channel = 5
-N = 6
-T = Float32
-adj = [0 1 0 0 0 0
-       1 0 0 1 1 1
-       0 0 0 0 0 1
-       0 1 0 0 1 0
-       0 1 0 1 0 1
-       0 1 1 0 1 0]
+@testset "cuda/msgpass" begin
+    in_channel = 10
+    out_channel = 5
+    N = 6
+    T = Float32
+    adj = [0 1 0 0 0 0
+           1 0 0 1 1 1
+           0 0 0 0 0 1
+           0 1 0 0 1 0
+           0 1 0 1 0 1
+           0 1 1 0 1 0]
 
-struct NewCudaLayer
-    weight
-end
-NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))
-@functor NewCudaLayer
+    struct NewCudaLayer
+        weight
+    end
+    NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
+    @functor NewCudaLayer
 
-(l::NewCudaLayer)(X) = GraphNeuralNetworks.propagate(l, X, +)
-GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-GraphNeuralNetworks.update(::NewCudaLayer, m, x) = m
+    (l::NewCudaLayer)(g) = GraphNeuralNetworks.propagate(l, g, +)
+    GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
+    GraphNeuralNetworks.update(::NewCudaLayer, x, m) = m
 
-X = rand(T, in_channel, N) |> gpu
-g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-l = NewCudaLayer(out_channel, in_channel) |> gpu
+    X = rand(T, in_channel, N) |> gpu
+    g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
+    l = NewCudaLayer(out_channel, in_channel) |> gpu
 
-@testset "cuda/msgpass" begin
     g_ = l(g)
     @test size(node_features(g_)) == (out_channel, N)
 end
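
For orientation, an illustrative CPU variant of this test (not part of the commit) exercises the graph-returning `propagate(l, g, +)` method the same way, minus the `gpu` movement and test-suite globals:

```julia
using Flux, GraphNeuralNetworks, Test
import GraphNeuralNetworks: message, update

struct TinyLayer
    weight
end
TinyLayer(out, in) = TinyLayer(randn(Float32, out, in))
Flux.@functor TinyLayer

# The graph-returning variant lets the layer be called as `l(g)`.
(l::TinyLayer)(g) = GraphNeuralNetworks.propagate(l, g, +)
message(l::TinyLayer, x_i, x_j, e_ij) = l.weight * x_j
update(::TinyLayer, x, m) = m

adj = [0 1 1; 1 0 1; 1 1 0]
g = GNNGraph(adj, ndata=rand(Float32, 10, 3))
g′ = TinyLayer(5, 10)(g)
@test size(node_features(g′)) == (5, 3)
```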
