80 changes: 80 additions & 0 deletions src/layers/conv.jl
@@ -1688,3 +1688,83 @@ function Base.show(io::IO, l::TransformerConv)
    (in, ein), out = l.channels
    print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

@doc raw"""
GraphormerLayer constructor.

Member: make the docstring style consistent with the other layers

# Arguments

- `ch`: A `Pair` of input and output channels. The input channel is a tuple `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features; the output channel is an integer giving the number of output features per node.
- `σ`: Activation function applied to the node features after the linear transformation. Defaults to `identity`.
- `heads`: Number of attention heads. Defaults to `1`.
- `concat`: Whether to concatenate the outputs of the heads or average them. Defaults to `true`.
- `negative_slope`: Slope of the negative part of the LeakyReLU activation. Defaults to `0.2`.
- `init`: Initialization function for the attention weights. Defaults to `glorot_uniform`.
- `bias`: Whether to include a bias term in the linear transformation. Defaults to `true`.
- `add_self_loops`: Whether to add self-loops to the graph. Defaults to `true`.

# Examples

    layer = GraphormerLayer((64, 32) => 128, relu; heads = 4, concat = true,
                            negative_slope = 0.1, init = glorot_uniform,
                            bias = true, add_self_loops = false)
"""
struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer
    dense_x::DX
    dense_e::DE
    bias::B
    a::A
    σ::F
    negative_slope::T
    channel::Pair{NTuple{2, Int}, Int}
    heads::Int
    concat::Bool
    add_self_loops::Bool
end

@functor GraphormerLayer

Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a)
GraphormerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphormerLayer((ch[1], 0) => ch[2], args...; kws...)

function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                         heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                         init = glorot_uniform, bias::Bool = true, add_self_loops = true)
    (in, ein), out = ch
    if add_self_loops
        @assert ein == 0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
    end

    dense_x = Dense(in, out * heads, bias = false)
    dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing
    b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
    a = init(ein > 0 ? 3out : 2out, heads)
    negative_slope = convert(Float32, negative_slope)
    GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops)
end

(l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
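
A minimal usage sketch, assuming `rand_graph` from GraphNeuralNetworks.jl and a layer without edge features (so the default `add_self_loops = true` is allowed):

g = rand_graph(10, 30)                        # 10 nodes, 30 edges
l = GraphormerLayer(4 => 8, relu; heads = 2)  # 4 input node features, 8 output features per head
x = rand(Float32, 4, 10)                      # node feature matrix
y = l(g, x)                                   # size (8 * 2, 10), since concat = true
g2 = l(GNNGraph(g, ndata = x))                # equivalently, attach the features to the graph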

function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix,
                              e::Union{Nothing, AbstractMatrix} = nothing)
    check_num_nodes(g, x)
    @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
    @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"

    if l.add_self_loops
        @assert e === nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
        g = add_self_loops(g)
    end

    _, chout = l.channel
    heads = l.heads

    Wx = l.dense_x(x)
    Wx = reshape(Wx, chout, heads, :)

    # a hand-written message passing
    m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)

Member: where is the message function defined?

Author: Apologies, I missed writing it. I've added a comment to remind myself to go back and add it; let me get on to that.

    α = softmax_edge_neighbors(g, m.logα)
    β = α .* m.Wxj
    x = aggregate_neighbors(g, +, β)

    if !l.concat
        x = mean(x, dims = 2)
    end
    x = reshape(x, :, size(x, 3))   # flatten the heads dimension into the feature dimension
    x = l.σ.(x .+ l.bias)           # apply bias and activation, as in the other attention layers
    return x
end
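
The `message` function used in `apply_edges` above is not defined in this diff, as the review notes. One possible definition, sketched here after the GATConv-style attention message already in this file rather than the full Graphormer attention, with `logα` the per-edge attention logits and `Wxj` the projected source-node features:

function message(l::GraphormerLayer, Wxi, Wxj, e)
    _, chout = l.channel
    heads = l.heads
    if e === nothing
        Wxx = vcat(Wxi, Wxj)                           # (2 * chout, heads, num_edges)
    else
        We = reshape(l.dense_e(e), chout, heads, :)
        Wxx = vcat(Wxi, Wxj, We)                       # (3 * chout, heads, num_edges)
    end
    logα = leakyrelu.(sum(l.a .* Wxx, dims = 1), l.negative_slope)
    return (; logα, Wxj)
end

With a definition along these lines, `softmax_edge_neighbors` normalizes `logα` over each node's incoming edges before the weighted aggregation in the forward pass.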