Skip to content

Commit 40154c7

Browse files
init redesign
1 parent de5bb60 commit 40154c7

File tree

7 files changed

+77
-130
lines changed

7 files changed

+77
-130
lines changed

docs/src/messagepassing.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ The message passing corresponds to the following operations
99
```math
1010
\begin{aligned}
1111
\mathbf{m}_{j\to i} &= \phi(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j\to i}) \\
12-
\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \square_{j\in N(i)} \mathbf{m}_{j\to i})\\
12+
\bar{\mathbf{m}}_{i} &= \square_{j\in N(i)} \mathbf{m}_{j\to i} \\
13+
14+
\mathbf{x}_{i}' &= \gamma_x(\mathbf{x}_{i}, \bar{\mathbf{m}}_{i})\\
1315
\mathbf{e}_{j\to i}^\prime &= \gamma_e(\mathbf{e}_{j \to i},\mathbf{m}_{j \to i})
1416
\end{aligned}
1517
```
1618
where ``\phi`` is expressed by the [`compute_message`](@ref) function,
1719
``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
18-
respectively.
20+
respectively. The generic aggregation ``\square`` is usually given by a summation
21+
``\sum``, a max, or a mean operation.
1922

2023
The message propagation mechanism internally relies on the [`NNlib.gather`](@ref)
2124
and [`NNlib.scatter`](@ref) methods.
@@ -50,12 +53,11 @@ function GCN(ch::Pair{Int,Int}, σ=identity)
5053
end
5154

5255
compute_message(l::GCN, xi, xj, eij) = l.weight * xj
53-
update_node(l::GCN, m, x) = m
5456

5557
function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
5658
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5759
x = x .* c'
58-
x, _ = propagate(l, g, +, x)
60+
x = propagate(l, g, +, x)
5961
x = x .* c'
6062
return l.σ.(x .+ l.bias)
6163
end

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,13 @@ export
5555
topk_index
5656

5757

58-
5958
include("gnngraph.jl")
6059
include("graph_conversions.jl")
6160
include("utils.jl")
62-
include("msgpass.jl")
6361
include("layers/basic.jl")
6462
include("layers/conv.jl")
6563
include("layers/pool.jl")
64+
include("msgpass.jl")
6665
include("deprecations.jl")
6766

6867
end

src/deprecations.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# Deprecated in v0.1
22

33
@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
4+
5+
# TO Deprecate
6+
# x, _ = propagate(l, g, l.aggr, x, e)

src/layers/conv.jl

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
4949
end
5050

5151
compute_message(l::GCNConv, xi, xj, eij) = xj
52-
update_node(l::GCNConv, m, x) = m
5352

5453
function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
5554
if l.add_self_loops
5655
g = add_self_loops(g)
5756
end
5857
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5958
x = x .* c'
60-
x, _ = propagate(l, g, +, x)
59+
x = propagate(l, g, +, xj=x)
6160
x = x .* c'
6261
return l.σ.(l.weight * x .+ l.bias)
6362
end
@@ -182,12 +181,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
182181
end
183182

184183
compute_message(l::GraphConv, x_i, x_j, e_ij) = x_j
185-
update_node(l::GraphConv, m, x) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
186184

187185
function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
188186
check_num_nodes(g, x)
189-
x, _ = propagate(l, g, l.aggr, x)
190-
x
187+
m = propagate(l, g, l.aggr, xj=x)
188+
x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
189+
return x
191190
end
192191

193192
function Base.show(io::IO, l::GraphConv)
@@ -264,8 +263,6 @@ function compute_message(l::GATConv, Wxi, Wxj)
264263
return (; α, m = α .* Wxj)
265264
end
266265

267-
update_node(l::GATConv, d̄, x) = d̄.m ./ d̄.α
268-
269266
function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
270267
check_num_nodes(g, x)
271268
g = add_self_loops(g)
@@ -275,7 +272,8 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
275272
Wx = l.weight * x
276273
Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes
277274

278-
x, _ = propagate(l, g, +, Wx) ## chout × nheads × nnodes
275+
d̄ = propagate(l, g, +; x=Wx) ## chout × nheads × nnodes
276+
x = d̄.m ./ d̄.α
279277

280278
if !l.concat
281279
x = mean(x, dims=2)
@@ -302,7 +300,7 @@ Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https:
302300
303301
Implements the recursion
304302
```math
305-
\mathbf{h}^{(0)}_i = \mathbf{x}_i || \mathbf{0} \\
303+
\mathbf{h}^{(0)}_i = [\mathbf{x}_i || \mathbf{0}] \\
306304
\mathbf{h}^{(l)}_i = GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
307305
```
308306
@@ -325,17 +323,14 @@ end
325323

326324
@functor GatedGraphConv
327325

328-
329326
function GatedGraphConv(out_ch::Int, num_layers::Int;
330327
aggr=+, init=glorot_uniform)
331328
w = init(out_ch, out_ch, num_layers)
332329
gru = GRUCell(out_ch, out_ch)
333330
GatedGraphConv(w, gru, out_ch, num_layers, aggr)
334331
end
335332

336-
337333
compute_message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
338-
update_node(l::GatedGraphConv, m, x) = m
339334

340335
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
341336
@non_differentiable fill!(x...)
@@ -350,7 +345,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
350345
end
351346
for i = 1:l.num_layers
352347
M = view(l.weight, :, :, i) * H
353-
M, _ = propagate(l, g, l.aggr, M)
348+
M = propagate(l, g, l.aggr; xj=M)
354349
H, _ = l.gru(H, M)
355350
end
356351
H
@@ -391,12 +386,10 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
391386

392387
compute_message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
393388

394-
update_node(l::EdgeConv, m, x) = m
395-
396-
function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
389+
function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix)
397390
check_num_nodes(g, X)
398-
X, _ = propagate(l, g, l.aggr, X)
399-
X
391+
x = propagate(l, g, l.aggr; x)
392+
return x
400393
end
401394

402395
function Base.show(io::IO, l::EdgeConv)
@@ -433,25 +426,21 @@ Flux.trainable(l::GINConv) = (l.nn,)
433426

434427
GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
435428

436-
437429
compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
438-
update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
439430

440-
function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
441-
check_num_nodes(g, X)
442-
X, _ = propagate(l, g, l.aggr, X)
443-
X
431+
function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
432+
check_num_nodes(g, x)
433+
m = propagate(l, g, l.aggr, xj=x)
434+
l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
444435
end
445436

446-
447437
function Base.show(io::IO, l::GINConv)
448438
print(io, "GINConv($(l.nn)")
449439
print(io, ", $(l.ϵ)")
450440
print(io, ")")
451441
end
452442

453443

454-
455444
@doc raw"""
456445
NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
457446
@@ -499,21 +488,17 @@ function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glor
499488
end
500489

501490
function compute_message(l::NNConv, x_i, x_j, e_ij)
502-
nin, nedges = size(x_i)
491+
nin, nedges = size(x_j)
503492
W = reshape(l.nn(e_ij), (:, nin, nedges))
504493
x_j = reshape(x_j, (nin, 1, nedges)) # needed by batched_mul
505494
m = NNlib.batched_mul(W, x_j)
506495
return reshape(m, :, nedges)
507496
end
508497

509-
function update_node(l::NNConv, m, x)
510-
l.σ.(l.weight*x .+ m .+ l.bias)
511-
end
512-
513498
function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
514499
check_num_nodes(g, x)
515-
x, _ = propagate(l, g, l.aggr, x, e)
516-
return x
500+
m = propagate(l, g, l.aggr, xj=x, e=e)
501+
return l.σ.(l.weight*x .+ m .+ l.bias)
517502
end
518503

519504
(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
@@ -533,7 +518,7 @@ GraphSAGE convolution layer from paper [Inductive Representation Learning on Lar
533518
534519
Performs:
535520
```math
536-
\mathbf{x}_i' = W [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
521+
\mathbf{x}_i' = W \cdot [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
537522
```
538523
539524
where the aggregation type is selected by `aggr`.
@@ -565,12 +550,12 @@ function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
565550
end
566551

567552
compute_message(l::SAGEConv, x_i, x_j, e_ij) = x_j
568-
update_node(l::SAGEConv, m, x) = l.σ.(l.weight * vcat(x, m) .+ l.bias)
569553

570554
function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
571555
check_num_nodes(g, x)
572-
x, _ = propagate(l, g, l.aggr, x)
573-
x
556+
m = propagate(l, g, l.aggr, xj=x)
557+
x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
558+
return x
574559
end
575560

576561
function Base.show(io::IO, l::SAGEConv)
@@ -633,21 +618,18 @@ function compute_message(l::ResGatedGraphConv, di, dj)
633618
return η .* dj.Vx
634619
end
635620

636-
update_node(l::ResGatedGraphConv, m, x) = m
637-
638621
function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
639622
check_num_nodes(g, x)
640623

641624
Ax = l.A * x
642625
Bx = l.B * x
643626
Vx = l.V * x
644627

645-
m, _ = propagate(l, g, +, (; Ax, Bx, Vx))
628+
m = propagate(l, g, +, xi=(; Ax), xj=(; Bx, Vx))
646629

647630
return l.σ.(l.U*x .+ m .+ l.bias)
648631
end
649632

650-
651633
function Base.show(io::IO, l::ResGatedGraphConv)
652634
out_channel, in_channel = size(l.A)
653635
print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)

0 commit comments

Comments
 (0)