diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index c1b71cd4d..f16e691e5 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -1689,3 +1689,114 @@ function Base.show(io::IO, l::TransformerConv)
     (in, ein), out = l.channels
     print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
 end
+
+@doc raw"""
+    GraphormerLayer((in, ein) => out; [σ, heads, concat, negative_slope, init, add_self_loops, bias, phi])
+
+A layer implementing the Graphormer transformer block from the paper
+["Do Transformers Really Perform Bad for Graph Representation?"](https://arxiv.org/abs/2106.05234),
+which combines transformers and graph neural networks.
+
+It applies a multi-head attention mechanism to node features and optionally edge features,
+together with layer normalization and an activation function,
+which makes it possible to capture both local neighborhood structure and long-range
+dependencies between nodes.
+
+The layer's forward pass is given by:
+```math
+\begin{aligned}
+h^{(0)} &= W x, \\
+h^{(l+1)} &= h^{(l)} + \mathrm{MHA}(\mathrm{LN}(h^{(l)})),
+\end{aligned}
+```
+where ``W`` is a learnable weight matrix, ``x`` are the node features,
+``\mathrm{LN}`` is layer normalization, and ``\mathrm{MHA}`` is multi-head attention.
+
+# Arguments
+
+- `in`: Dimension of input node features.
+- `ein`: Dimension of the edge features; if 0, no edge features are used.
+- `out`: Dimension of the output node features.
+- `σ`: Activation function applied to the output node features. Default `identity`.
+- `heads`: Number of attention heads. Default `1`.
+- `concat`: Concatenate layer output or not. If not, the layer output is averaged
+  over the heads. Default `true`.
+- `negative_slope`: The slope of the negative part of the LeakyReLU activation function
+  used in the attention messages. Default `0.2`.
+- `init`: Initializing function for the weight matrices. Default `glorot_uniform`.
+- `add_self_loops`: Add self loops to the input graph. Default `true`.
+- `bias`: If set, the layer also learns an additive bias for the nodes. Default `true`.
+- `phi`: Number of Phi blocks (layer normalization followed by multi-head attention)
+  applied in the forward pass. Default `2`.
+"""
+struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B, LN, MHA} <: GNNLayer
+    dense_x::DX
+    dense_e::DE
+    bias::B
+    a::A
+    σ::F
+    negative_slope::T
+    channel::Pair{NTuple{2, Int}, Int}
+    heads::Int
+    concat::Bool
+    add_self_loops::Bool
+    phi::Int
+    LN::LN
+    MHA::MHA
+end
+
+function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
+                         heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
+                         init = glorot_uniform, bias::Bool = true,
+                         add_self_loops = true, phi::Int = 2)
+    (in, ein), out = ch
+    if add_self_loops
+        @assert ein == 0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    dense_x = Dense(in, out * heads, bias = false)
+    dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing
+    b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
+    a = init(ein > 0 ? 3out : 2out, heads)
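+    # Note: `a` holds one attention vector per head; its length (2out, or 3out when edge
+    # features are used) matches the feature concatenation built in `message` below.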
+    negative_slope = convert(Float32, negative_slope)
+    LN = LayerNorm(out)
+    # Uses Flux's built-in MultiHeadAttention; `out` must be divisible by `heads`.
+    MHA = Flux.MultiHeadAttention(out; nheads = heads)
+    GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat,
+                    add_self_loops, phi, LN, MHA)
+end
+
+function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix,
+                              e::Union{Nothing, AbstractMatrix} = nothing)
+    check_num_nodes(g, x)
+    @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
+    @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
+
+    if l.add_self_loops
+        @assert e === nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+        g = add_self_loops(g)
+    end
+
+    _, chout = l.channel
+    heads = l.heads
+    Wx = l.dense_x(x)
+    Wx = reshape(Wx, chout, heads, :)   # (out, heads, num_nodes)
+    h = Wx
+    for _ in 1:l.phi
+        # Phi block: h^{(l+1)} = h^{(l)} + MHA(LN(h^{(l)})).
+        # Flux's MultiHeadAttention returns (output, attention_scores); only the output is used.
+        h = h + first(l.MHA(l.LN(h)))
+    end
+    if !l.concat
+        h = mean(h, dims = 2)
+    end
+    h = reshape(h, :, size(h, 3))
+    if l.bias !== false
+        h = h .+ l.bias
+    end
+    h = l.σ.(h)
+    return h
+end
+
+function message(l::GraphormerLayer, xi, xj, e)
+    # GAT-style attention message; not yet wired into the forward pass above.
+    _, chout = l.channel
+    θ = vcat(xi, xj)                    # (2out, heads, num_edges)
+    if l.dense_e !== nothing
+        fe = l.dense_e(e)
+        fe = reshape(fe, chout, l.heads, :)
+        θ = vcat(θ, fe)                 # (3out, heads, num_edges)
+    end
+    # Attention logits per head; `softmax` (re-exported from NNlib via Flux) can be used
+    # directly, but note it normalizes here over all edges, not per destination node.
+    logits = leakyrelu.(sum(l.a .* θ, dims = 1), l.negative_slope)
+    α = softmax(logits, dims = 3)
+    W = α .* θ
+    W = l.σ.(W)
+    return W
+end
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index f4f2a9282..5631b1383 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -333,3 +333,41 @@ end
                    outsize = (in_channel, g.num_nodes))
     end
 end
+
+@testset "GraphormerLayer" begin
+    heads = 2
+    concat = true
+    negative_slope = 0.2
+    add_self_loops = true
+    phi = 2
+    ch = (in_channel, 0) => out_channel
+    σ = relu
+    init = glorot_uniform
+    bias = true
+    l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
+                        negative_slope = negative_slope, init = init,
+                        bias = bias, add_self_loops = add_self_loops, phi = phi)
+    for g in test_graphs
+        test_layer(l, g, rtol = RTOL_HIGH,
+                   outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
+    end
+    l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
+                        negative_slope = negative_slope, init = init,
+                        bias = bias, add_self_loops = false, phi = phi)
+    test_layer(l, g1, rtol = RTOL_HIGH,
+               outsize = (concat ? heads * out_channel : out_channel, g1.num_nodes))
+    ein = 3
+    ch = (in_channel, ein) => out_channel
+    # Edge features require add_self_loops = false (see the constructor assertion).
+    l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
+                        negative_slope = negative_slope, init = init,
+                        bias = bias, add_self_loops = false, phi = phi)
+    g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
+    test_layer(l, g, rtol = RTOL_HIGH,
+               outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
+    l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
+                        negative_slope = negative_slope, init = init,
+                        bias = false, add_self_loops = false, phi = phi)
+    test_layer(l, g, rtol = RTOL_HIGH,
+               outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
+    @test length(Flux.params(l)) == (ein > 0 ? 7 : 6)
+end
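A minimal usage sketch of the new layer (not part of the patch). It assumes `GraphormerLayer` is exported by GraphNeuralNetworks.jl once this change lands and, since the constructor relies on Flux's `MultiHeadAttention`, that `out` is divisible by `heads`:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)                 # random graph with 10 nodes and 30 edges
x = rand(Float32, 8, g.num_nodes)      # 8 input features per node

# (in, ein) => out with no edge features; 2 heads, 2 Phi blocks
l = GraphormerLayer((8, 0) => 16, relu; heads = 2, phi = 2)

y = l(g, x)                            # with concat = true, size(y) == (16 * 2, g.num_nodes)
```

With `concat = false` the heads are averaged instead, giving an output of size `(16, g.num_nodes)`.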