91 changes: 91 additions & 0 deletions src/layers/conv.jl
@@ -1688,3 +1688,94 @@ function Base.show(io::IO, l::TransformerConv)
(in, ein), out = l.channels
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

@doc raw"""
    GraphormerLayer((in, ein) => out, σ = identity; heads = 1, concat = true,
                    negative_slope = 0.2, init = glorot_uniform, bias = true,
                    add_self_loops = true, phi = 1)

Graphormer-style attention layer: node features are linearly transformed and then passed through `phi` pre-norm multi-head self-attention blocks with residual connections.
Member
make the docstring style consistent with the other layers


# Arguments

- `ch`: A `Pair` of input and output channels. The input is a tuple `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features; the output is an integer giving the number of output features per node.
- `σ`: The activation function applied to the node features after the attention blocks. Defaults to `identity`.
- `heads`: The number of attention heads. Defaults to `1`.
- `concat`: Whether to concatenate the output of each head or average them. Defaults to `true`.
- `negative_slope`: The slope of the negative part of the LeakyReLU activation function. Defaults to `0.2`.
- `init`: The initialization function for the attention weights. Defaults to `glorot_uniform`.
- `bias`: Whether to add a learnable bias to the output. Defaults to `true`.
- `add_self_loops`: Whether to add self-loops to the input graph. Defaults to `true`.
- `phi`: The number of stacked attention blocks applied in the forward pass. Defaults to `1`.

# Examples

    layer = GraphormerLayer((64, 32) => 128, relu; heads = 4, concat = true,
                            negative_slope = 0.1, init = glorot_uniform,
                            bias = true, add_self_loops = false)
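
A usage sketch on a small random graph (arbitrary feature sizes, no edge features):

    g = rand_graph(10, 30)                        # 10 nodes, 30 edges
    x = rand(Float32, 8, 10)                      # 8 features per node
    l = GraphormerLayer((8, 0) => 16; heads = 4)
    y = l(g, x)                                   # 16 * 4 = 64 output features per node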
"""
struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B, LNT, MHAT} <: GNNLayer
    dense_x::DX
    dense_e::DE
    bias::B
    a::A
    σ::F
    negative_slope::T
    channel::Pair{NTuple{2, Int}, Int}
    heads::Int
    concat::Bool
    add_self_loops::Bool
    phi::Int    # number of stacked attention blocks applied in the forward pass
    LN::LNT     # layer normalization applied before self-attention
    MHA::MHAT   # multi-head self-attention over the transformed node features
end

@functor GraphormerLayer

Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a, l.LN, l.MHA)

function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                         heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                         init = glorot_uniform, bias::Bool = true, add_self_loops = true,
                         phi::Int = 1)
    (in, ein), out = ch
    if add_self_loops
        @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end

    dense_x = Dense(in, out * heads, bias = false)
    dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing
    b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
    a = init(ein > 0 ? 3out : 2out, heads)
    negative_slope = convert(Float32, negative_slope)
    # Pre-attention layer norm and multi-head self-attention over the out * heads
    # transformed node features (requires a Flux version providing MultiHeadAttention).
    ln = Flux.LayerNorm(out * heads)
    mha = Flux.MultiHeadAttention(out * heads; nheads = heads)
    GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat,
                    add_self_loops, phi, ln, mha)
end

(l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix,
                              e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
    @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
    # NOTE: edge features are validated here but not yet incorporated into the attention below.

    if l.add_self_loops
        @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
        g = add_self_loops(g)
    end

    _, chout = l.channel
    heads = l.heads
    Wx = l.dense_x(x)                      # (chout * heads, num_nodes)
    # Treat the nodes as the sequence dimension for self-attention: (features, nodes, batch = 1).
    Wx = reshape(Wx, chout * heads, :, 1)
    h = Wx
    # phi stacked pre-norm attention blocks with residual connections.
    for _ in 1:l.phi
        attn, _ = l.MHA(l.LN(h))           # MultiHeadAttention returns (output, attention_scores)
        h = attn + h
    end
    h = reshape(h, chout * heads, :)       # back to (features, num_nodes)
    if !l.concat
        # Average over the heads instead of concatenating them.
        h = reshape(mean(reshape(h, chout, heads, :), dims = 2), chout, :)
    end
    if l.bias !== false
        h = h .+ l.bias
    end
    h = l.σ.(h)
    return h
end
# This GAT-style per-edge scoring is not yet wired into the forward pass above. It assumes
# `xi`/`xj` carry the transformed target/source node features along the first dimension,
# matching the first dimension of the attention weights `a`.
function message(l::GraphormerLayer, xi, xj, e)
    θ = vcat(xi, xj)                       # concatenate along the feature dimension
    if l.dense_e !== nothing
        fe = l.dense_e(e)
        θ = vcat(θ, fe)
    end
    W = l.a' * θ                           # (heads, num_edges) attention logits
    W = leakyrelu.(W, l.negative_slope)
    return W
end
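
# Sketch of the edge-feature path (hypothetical sizes). When `ein > 0`, `add_self_loops`
# must be disabled and an edge-feature matrix with one column per edge is passed explicitly.
# Note that the current forward pass validates but does not yet consume the edge features;
# only the node-feature attention path is applied.
#
#     g = rand_graph(10, 30)
#     x = rand(Float32, 8, 10)                 # node features
#     e = rand(Float32, 3, 30)                 # 3 features per edge
#     l = GraphormerLayer((8, 3) => 16; heads = 4, add_self_loops = false)
#     y = l(g, x, e)                           # size(y) == (64, 10)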