src/layers/conv.jl: 111 additions & 0 deletions
@@ -1689,3 +1689,114 @@ function Base.show(io::IO, l::TransformerConv)
(in, ein), out = l.channels
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

@doc raw"""
    GraphormerLayer((in, ein) => out, σ = identity; heads = 1, concat = true,
                    negative_slope = 0.2, init = glorot_uniform,
                    add_self_loops = true, bias = true, phi = 2)

A layer implementing the Graphormer model from the paper
["Do Transformers Really Perform Bad for Graph Representation?"](https://arxiv.org/abs/2106.05234),
which combines transformers and graph neural networks.

It applies multi-head attention to the node features (and, optionally, to edge features),
together with layer normalization and an activation function,
which makes it possible to capture both intra-node and inter-node dependencies.

The layer's forward pass is given by
```math
h^{(0)} = W x, \qquad
h^{(l+1)} = h^{(l)} + \mathrm{MHA}\big(\mathrm{LN}(h^{(l)})\big),
```
where ``W`` is a learnable weight matrix, ``x`` are the input node features,
``\mathrm{LN}`` denotes layer normalization, and ``\mathrm{MHA}`` multi-head attention.

# Arguments

- `in`: The dimension of input node features.
- `ein`: The dimension of input edge features; if `0`, edge features are not used.
- `out`: The dimension of output node features.
- `σ`: Activation function applied to the output node features. Default `identity`.
- `heads`: Number of attention heads. Default `1`.
- `concat`: If `true`, the per-head outputs are concatenated; otherwise they are averaged. Default `true`.
- `negative_slope`: The slope of the negative part of the LeakyReLU activation function. Default `0.2`.
- `init`: Initializer for the weight matrices. Default `glorot_uniform`.
- `add_self_loops`: Add self-loops to the input graph. Not supported together with edge features (`ein > 0`). Default `true`.
- `bias`: If `true`, the layer also learns an additive bias for the nodes. Default `true`.
- `phi`: Number of Phi blocks (layer normalization followed by multi-head attention) applied in the forward pass. Default `2`.
"""
struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B, LN, MHA} <: GNNLayer
dense_x::DX
dense_e::DE
bias::B
a::A
σ::F
negative_slope::T
channel::Pair{NTuple{2, Int}, Int}
heads::Int
concat::Bool
add_self_loops::Bool
phi::Int
LN::LN
MHA::MHA
end

function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                         heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                         init = glorot_uniform, bias::Bool = true, add_self_loops = true, phi::Int = 2)
    (in, ein), out = ch
    if add_self_loops
        @assert ein == 0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end

    dense_x = Dense(in, out * heads, bias = false)
    dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing
    b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
    a = init(ein > 0 ? 3out : 2out, heads)
    negative_slope = convert(Float32, negative_slope)
    # normalization and attention act on the (out * heads)-dimensional projected node features;
    # Flux's MultiHeadAttention handles the head splitting internally
    LN = LayerNorm(out * heads)
    MHA = MultiHeadAttention(out * heads; nheads = heads)
    GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, phi, LN, MHA)
end

function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
    @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"

    if l.add_self_loops
        @assert e === nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
        g = add_self_loops(g)
    end

    _, chout = l.channel
    heads = l.heads
    Wx = l.dense_x(x)                          # (chout * heads) × num_nodes
    h = reshape(Wx, size(Wx, 1), :, 1)         # add a trivial batch dimension for MultiHeadAttention
    for _ in 1:l.phi
        y, _ = l.MHA(l.LN(h))                  # MultiHeadAttention returns (output, attention_scores)
        h = h + y                              # residual connection, as in the docstring formula
    end
    h = reshape(h, chout * heads, :)           # drop the batch dimension
    if !l.concat
        h = reshape(mean(reshape(h, chout, heads, :), dims = 2), chout, :)   # average over heads
    end
    if l.bias !== false
        h = h .+ l.bias
    end
    h = l.σ.(h)
    return h
end
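
# Shape walk-through for a call `l(g, x)` (sizes are illustrative, not part of the API):
# with in = 8, out = 16, heads = 2, concat = true, and g.num_nodes = 10,
#   x                  :: 8 × 10
#   l.dense_x(x)       :: 32 × 10
#   h inside the loop  :: 32 × 10 × 1   (features × nodes × batch for MultiHeadAttention)
#   returned h         :: 32 × 10       (16 × 10 if concat = false)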

function message(l::GraphormerLayer, xi, xj, e)
    _, chout = l.channel
    # xi and xj carry the projected features with size (chout, heads, num_edges);
    # concatenate source and destination along the feature dimension to match l.a
    θ = vcat(xi, xj)
    if l.dense_e !== nothing
        fe = l.dense_e(e)
        fe = reshape(fe, chout, l.heads, :)
        θ = vcat(θ, fe)
    end
    # per-head attention logits; softmax comes from NNlib and is re-exported by Flux
    logits = sum(l.a .* θ, dims = 1)           # 1 × heads × num_edges
    α = softmax(logits, dims = 3)
    W = α .* θ
    W = l.σ.(W)
    return W
end
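
# NOTE: `message` is not called by the forward pass above; it is a message-passing
# formulation of the attention step and would have to be wired in through the
# library's propagate / apply_edges machinery to take effect.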
test/layers/conv.jl: 38 additions & 0 deletions
@@ -333,3 +333,41 @@ end
outsize = (in_channel, g.num_nodes))
end
end

@testset "GraphormerLayer" begin
heads = 2
concat = true
negative_slope = 0.2
add_self_loops = true
phi = 2
ch = (in_channel, 0) => out_channel
σ = relu
init = glorot_uniform
bias = true
l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
negative_slope = negative_slope, init = init,
bias = bias, add_self_loops = add_self_loops, phi = phi)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH,
outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
end
l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
negative_slope = negative_slope, init = init,
bias = bias, add_self_loops = false, phi = phi)
test_layer(l, g1, rtol = RTOL_HIGH,
outsize = (concat ? heads * out_channel : out_channel, g1.num_nodes))
ein = 3
ch = (in_channel, ein) => out_channel
l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
negative_slope = negative_slope, init = init,
bias = bias, add_self_loops = false, phi = phi)  # edge features require add_self_loops = false
g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges))
test_layer(l, g, rtol = RTOL_HIGH,
outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
l = GraphormerLayer(ch, σ; heads = heads, concat = concat,
negative_slope = negative_slope, init = init,
bias = false, add_self_loops = false, phi = phi)  # edge features require add_self_loops = false
test_layer(l, g, rtol = RTOL_HIGH,
outsize = (concat ? heads * out_channel : out_channel, g.num_nodes))
@test length(Flux.params(l)) == (ein > 0 ? 7 : 6)
end