
Commit 26a979e

Add a transformer-like convolutional layer (#249)
* TransformerConv developed
* Typo fixed
* Typo fixed
* Julia 1.7 line-breaks in strings avoided
* Update src/layers/conv.jl (Sure, thanks! Co-authored-by: Carlo Lucibello <[email protected]>)
* Update src/layers/conv.jl (Co-authored-by: Carlo Lucibello <[email protected]>)
* Polishing of documentation, beta renamed to gated
* Fix beta -> gating in tests

Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 2dc0aee commit 26a979e

File tree

6 files changed: +251 -7 lines changed

- docs/src/api/conv.md
- src/GraphNeuralNetworks.jl
- src/layers/conv.jl
- test/examples/node_classification_cora.jl
- test/layers/conv.jl
- test/test_utils.jl

docs/src/api/conv.md

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ The table below lists all graph convolutional layers implemented in the *GraphNe
 | [`ResGatedGraphConv`](@ref) | | | |
 | [`SAGEConv`](@ref) || | |
 | [`SGConv`](@ref) || | |
+| [`TransformerConv`](@ref) | | ||
 
 
 ## Docs

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ export
   ResGatedGraphConv,
   SAGEConv,
   SGConv,
+  TransformerConv,
 
   # layers/pool
   GlobalPool,

src/layers/conv.jl

Lines changed: 206 additions & 0 deletions
@@ -1452,3 +1452,209 @@ function Base.show(io::IO, l::EGNNConv)
     end
     print(io, ")")
 end
+
+
+@doc raw"""
+    TransformerConv((in, ein) => out; [heads, concat, init, add_self_loops, bias_qkv,
+        bias_root, root_weight, gating, skip_connection, batch_norm, ff_channels])
+
+The transformer-like multi-head attention convolutional operator from the
+[Masked Label Prediction: Unified Message Passing Model for Semi-Supervised
+Classification](https://arxiv.org/abs/2009.03509) paper, which also considers
+edge features.
+It can further be configured as the transformer-like convolutional operator from the
+[Attention, Learn to Solve Routing Problems!](https://arxiv.org/abs/1803.08475) paper,
+including a successive feed-forward network as well as skip layers and batch normalization.
+
+The layer's basic forward pass is given by
+```math
+x_i' = W_1 x_i + \sum_{j \in \mathcal{N}(i)} \alpha_{ij} (W_2 x_j + W_6 e_{ij})
+```
+where the attention scores are
+```math
+\alpha_{ij} = \mathrm{softmax}\left(\frac{(W_3 x_i)^T (W_4 x_j + W_6 e_{ij})}{\sqrt{d}}\right).
+```
+
+Optionally, the aggregated value can be combined with the transformed root node
+features through a gating mechanism:
+```math
+x'_i = \beta_i W_1 x_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
+\alpha_{ij} W_2 x_j \right)}_{=m_i}
+```
+with
+```math
+\beta_i = \textrm{sigmoid}(W_5^{\top} [ W_1 x_i, m_i, W_1 x_i - m_i ]).
+```
+
+# Arguments
+
+- `in`: Dimension of input features, which also corresponds to the dimension of
+    the output features.
+- `ein`: Dimension of the edge features; if 0, no edge features will be used.
+- `out`: Dimension of the output.
+- `heads`: Number of heads in output. Default `1`.
+- `concat`: Concatenate layer output or not. If not, layer output is averaged
+    over the heads. Default `true`.
+- `init`: Weight matrices' initializing function. Default `glorot_uniform`.
+- `add_self_loops`: Add self loops to the input graph. Default `false`.
+- `bias_qkv`: If set, bias is used in the key, query and value transformations for nodes.
+    Default `true`.
+- `bias_root`: If set, the layer will also learn an additive bias for the root when root
+    weight is used. Default `true`.
+- `root_weight`: If set, the layer will add the transformed root node features
+    to the output. Default `true`.
+- `gating`: If set, the aggregated output and the transformed root node features are
+    combined by a gating mechanism. Default `false`.
+- `skip_connection`: If set, a skip connection from the input is added to the output.
+    Default `false`.
+- `batch_norm`: If set, batch normalization is applied to the output. Default `false`.
+- `ff_channels`: If positive, a feed-forward NN is appended, with its first layer having
+    the given number of hidden nodes; this NN also gets a skip connection and batch
+    normalization if the respective parameters are set. Default `0`.
+"""
+struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLayer
+    W1::TW1
+    W2::TW2
+    W3::TW3
+    W4::TW4
+    W5::TW5
+    W6::TW6
+    FF::TFF
+    BN1::TBN1
+    BN2::TBN2
+    channels::Pair{NTuple{2,Int},Int}
+    heads::Int
+    add_self_loops::Bool
+    concat::Bool
+    skip_connection::Bool
+    sqrt_out::Float32
+end
+
+@functor TransformerConv
+
+Flux.trainable(l::TransformerConv) = (l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
+
+TransformerConv(ch::Pair{Int,Int}, args...; kws...) = TransformerConv((ch[1], 0) => ch[2], args...; kws...)
+
+function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
+                         heads::Int = 1,
+                         concat::Bool = true,
+                         init = glorot_uniform,
+                         add_self_loops::Bool = false,
+                         bias_qkv = true,
+                         bias_root::Bool = true,
+                         root_weight::Bool = true,
+                         gating::Bool = false,
+                         skip_connection::Bool = false,
+                         batch_norm::Bool = false,
+                         ff_channels::Int = 0)
+
+    (in, ein), out = ch
+
+    if add_self_loops
+        @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    W1 = root_weight ? Dense(in, out * (concat ? heads : 1); bias=bias_root, init=init) : nothing
+    W2 = Dense(in => out*heads; bias=bias_qkv, init=init)
+    W3 = Dense(in => out*heads; bias=bias_qkv, init=init)
+    W4 = Dense(in => out*heads; bias=bias_qkv, init=init)
+    out_mha = out * (concat ? heads : 1)
+    W5 = gating ? Dense(3 * out_mha => 1, sigmoid; bias=false, init=init) : nothing
+    W6 = ein > 0 ? Dense(ein => out*heads; bias=bias_qkv, init=init) : nothing
+    FF = ff_channels > 0 ? Chain(
+            Dense(out_mha => ff_channels, relu),
+            Dense(ff_channels => out_mha)
+         ) : nothing
+    BN1 = batch_norm ? BatchNorm(out_mha) : nothing
+    BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing
+
+    return TransformerConv(W1, W2, W3, W4, W5, W6, FF, BN1, BN2,
+                           ch, heads, add_self_loops, concat, skip_connection, Float32(√out))
+end
+
+function (l::TransformerConv)(g::GNNGraph, x::AbstractMatrix,
+                              e::Union{AbstractMatrix, Nothing}=nothing)
+    check_num_nodes(g, x)
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+    end
+
+    out = l.channels[2]
+    heads = l.heads
+    W1x = !isnothing(l.W1) ? l.W1(x) : nothing
+    W2x = reshape(l.W2(x), out, heads, :)
+    W3x = reshape(l.W3(x), out, heads, :)
+    W4x = reshape(l.W4(x), out, heads, :)
+    W6e = !isnothing(l.W6) ? reshape(l.W6(e), out, heads, :) : nothing
+
+    m = apply_edges(message_uij, g, l; xi=(; W3x), xj=(; W4x), e=(; W6e))
+    α = softmax_edge_neighbors(g, m)
+    α_val = propagate(message_main, g, +, l; xi=(; W3x), xj=(; W2x), e=(; W6e, α))
+
+    h = α_val
+    if l.concat
+        h = reshape(h, out * heads, :)  # concatenate heads
+    else
+        h = mean(h, dims=2)             # average heads
+        h = reshape(h, out, :)
+    end
+
+    if !isnothing(W1x)  # root_weight
+        if !isnothing(l.W5)  # gating
+            β = l.W5(vcat(h, W1x, h .- W1x))
+            h = β .* W1x + (1f0 .- β) .* h
+        else
+            h += W1x
+        end
+    end
+
+    if l.skip_connection
+        @assert size(h, 1) == size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used"
+        h += x
+    end
+    if !isnothing(l.BN1)
+        h = l.BN1(h)
+    end
+
+    if !isnothing(l.FF)
+        h1 = h
+        h = l.FF(h)
+        if l.skip_connection
+            h += h1
+        end
+        if !isnothing(l.BN2)
+            h = l.BN2(h)
+        end
+    end
+
+    return h
+end
+
+(l::TransformerConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
+
+function message_uij(l::TransformerConv, xi, xj, e)
+    key = xj.W4x
+    if !isnothing(e.W6e)
+        key += e.W6e
+    end
+    uij = sum(xi.W3x .* key, dims=1) ./ l.sqrt_out
+    return uij
+end
+
+function message_main(l::TransformerConv, xi, xj, e)
+    val = xj.W2x
+    if !isnothing(e.W6e)
+        val += e.W6e
+    end
+    return e.α .* val
+end
+
+function Base.show(io::IO, l::TransformerConv)
+    (in, ein), out = l.channels
+    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
+end
+
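For orientation, here is a minimal usage sketch of the new layer (an editor-added illustration, not part of the commit): the toy graph, feature sizes, and keyword choices are assumptions, while the constructor signature and the `l(g, x, e)` forward call follow the code added above.

```julia
using GraphNeuralNetworks

# Toy undirected graph given as an adjacency matrix (4-node ring), chosen only for illustration.
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
g = GNNGraph(adj)

nin, ein, nout, heads = 8, 2, 16, 4
x = rand(Float32, nin, g.num_nodes)   # node features, one column per node
e = rand(Float32, ein, g.num_edges)   # edge features, one column per (directed) edge

# Edge-feature variant with gating, as documented in the docstring above.
l = TransformerConv((nin, ein) => nout; heads, gating=true, bias_qkv=true)
y = l(g, x, e)
@assert size(y) == (nout * heads, g.num_nodes)  # concat=true (the default) stacks the heads
```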

test/examples/node_classification_cora.jl

Lines changed: 8 additions & 5 deletions
@@ -16,10 +16,10 @@ end
 
 # arguments for the `train` function
 Base.@kwdef mutable struct Args
-    η = 5f-3         # learning rate
+    η = 5f-3         # learning rate
     epochs = 10      # number of epochs
-    seed = 17        # set seed > 0 for reproducibility
-    usecuda = false  # if true use cuda (if available)
+    seed = 17        # set seed > 0 for reproducibility
+    usecuda = false  # if true use cuda (if available)
     nhidden = 64     # dimension of hidden features
 end
 
@@ -58,8 +58,8 @@ function train(Layer; verbose=false, kws...)
 
     ## TRAINING
     function report(epoch)
-        train = eval_loss_accuracy(X, y, train_ids, model, g)
-        test = eval_loss_accuracy(X, y, test_ids, model, g)
+        train = eval_loss_accuracy(X, y, train_mask, model, g)
+        test = eval_loss_accuracy(X, y, test_mask, model, g)
         println("Epoch: $epoch  Train: $(train)  Test: $(test)")
     end
 
@@ -86,6 +86,8 @@ function train_many(; usecuda=false)
         ("SAGEConv", (nin, nout) -> SAGEConv(nin => nout, relu)),
         ("GATConv", (nin, nout) -> GATConv(nin => nout, relu)),
        ("GINConv", (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr=mean)),
+        ("TransformerConv", (nin, nout) -> TransformerConv(nin => nout, concat=false,
+                                add_self_loops=true, root_weight=false, heads=2))
         ## ("ChebConv", (nin, nout) -> ChebConv(nin => nout, 2)), # not working on gpu
         ## ("NNConv", (nin, nout) -> NNConv(nin => nout)), # needs edge features
         ## ("GatedGraphConv", (nin, nout) -> GatedGraphConv(nout, 2)), # needs nin = nout
@@ -94,6 +96,7 @@ function train_many(; usecuda=false)
 
        @show layer
        @time train_res, test_res = train(Layer; usecuda, verbose=false)
+       # @show train_res, test_res
        @test train_res.acc > 94
        @test test_res.acc > 70
    end
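A brief aside on the configuration registered above: with `concat=false` the heads are averaged rather than concatenated, so the layer still maps `nin => nout` features even with `heads=2`. A small dimension-check sketch (editor-added, not from the commit; the toy graph and sizes are assumptions):

```julia
using GraphNeuralNetworks

nin, nout, heads = 8, 7, 2
adj = [0 1 1;
       1 0 1;
       1 1 0]                  # toy 3-node graph, assumed for illustration
g = GNNGraph(adj)
x = rand(Float32, nin, g.num_nodes)

l = TransformerConv(nin => nout; heads, concat=false, add_self_loops=true, root_weight=false)
y = l(g, x)
@assert size(y) == (nout, g.num_nodes)   # heads are averaged, not concatenated
```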

test/layers/conv.jl

Lines changed: 34 additions & 1 deletion
@@ -266,7 +266,7 @@
     @testset "GMMConv" begin
         ein_channel = 10
         K = 5
-        l = GMMConv((in_channel, ein_channel )=> out_channel, K=K)
+        l = GMMConv((in_channel, ein_channel ) => out_channel, K=K)
         for g in test_graphs
             g = GNNGraph(g, edata=rand(Float32, ein_channel, g.num_edges))
             test_layer(l, g, rtol=RTOL_HIGH, outsize = (out_channel, g.num_nodes))
@@ -300,4 +300,37 @@
         @test size(hnew) == (hout, g.num_nodes)
         @test size(xnew) == (in_channel, g.num_nodes)
     end
+
+    @testset "TransformerConv" begin
+        ein = 2
+        heads = 3
+        # used like in Kool et al., 2019
+        l = TransformerConv(in_channel * heads => in_channel; heads, add_self_loops=true,
+            root_weight=false, ff_channels=10, skip_connection=true, batch_norm=false)
+        # batch_norm=false here for tests to pass; true in paper
+        for adj in [adj1, adj_single_vertex]
+            g = GNNGraph(adj, ndata=rand(T, in_channel * heads, size(adj, 1)), graph_type=GRAPH_T)
+            test_layer(l, g, rtol=RTOL_LOW,
+                exclude_grad_fields = [:negative_slope],
+                outsize=(in_channel * heads, g.num_nodes))
+        end
+        # used like in Shi et al., 2021
+        l = TransformerConv((in_channel, ein) => in_channel; heads, gating=true, bias_qkv=true)
+        for g in test_graphs
+            g = GNNGraph(g, edata=rand(T, ein, g.num_edges))
+            test_layer(l, g, rtol=RTOL_LOW,
+                exclude_grad_fields = [:negative_slope],
+                outsize=(in_channel * heads, g.num_nodes))
+        end
+        # test averaging heads
+        l = TransformerConv(in_channel => in_channel; heads, concat=false, bias_root=false,
+            root_weight=false)
+        for g in test_graphs
+            test_layer(l, g, rtol=RTOL_LOW,
+                exclude_grad_fields = [:negative_slope],
+                outsize=(in_channel, g.num_nodes))
+        end
+    end
 end
+
+
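The first testset mirrors the Kool et al., 2019 setup: with `skip_connection=true` the forward pass adds the input to the output, so the input dimension must equal `out * heads`, which is why the layer is built as `in_channel * heads => in_channel`. A standalone sketch of that constraint (editor-added; the concrete sizes and toy graph are assumptions):

```julia
using GraphNeuralNetworks

in_channel, heads = 3, 3
adj = [0 1 1; 1 0 1; 1 1 0]          # assumed toy graph
g = GNNGraph(adj)
x = rand(Float32, in_channel * heads, g.num_nodes)

# in == out * heads, so the skip connection h += x is well defined.
l = TransformerConv(in_channel * heads => in_channel; heads, add_self_loops=true,
                    root_weight=false, ff_channels=10, skip_connection=true)
y = l(g, x)
@assert size(y) == (in_channel * heads, g.num_nodes)
```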

test/test_utils.jl

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ end
 # Tests also gradient on cpu and gpu comparing with
 # finite difference methods.
 # Test gradients with respects to layer weights and to input.
-# If `g` has edge features, it is assumed that the layer can be
+# If `g` has edge features, it is assumed that the layer can
 # use them in the forward pass as `l(g, x, e)`.
 # Test also gradient with repspect to `e`.
 function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
