Adding Graphormer layer #275


Closed
wants to merge 14 commits into from
80 changes: 80 additions & 0 deletions src/layers/conv.jl
@@ -1688,3 +1688,83 @@ function Base.show(io::IO, l::TransformerConv)
(in, ein), out = l.channels
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end

@doc raw"""
GraphormerLayer constructor.
Member

make the docstring style consistent with the other layers


Parameters:
- `ch`: A `Pair` object representing the input and output channels of the layer. The input channel should be a tuple of the form `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features. The output channel should be an integer representing the number of output features for each node.
- `σ`: The activation function to apply to the node features after the linear transformation. Defaults to `identity`.
- `heads`: The number of attention heads to use. Defaults to 1.
- `concat`: Whether to concatenate the output of each head or average them. Defaults to `true`.
- `negative_slope`: The slope of the negative part of the LeakyReLU activation function. Defaults to 0.2.
- `init`: The initialization function to use for the attention weights. Defaults to `glorot_uniform`.
- `bias`: Whether to include a bias term in the linear transformation. Defaults to `true`.
- `add_self_loops`: Whether to add self-loops to the graph. Defaults to `true`.

Example:
layer = GraphormerLayer((64, 32) => 128, σ = relu, heads = 4, concat = true, negative_slope = 0.1, init = glorot_uniform, bias = true, add_self_loops = false)
"""
struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer
dense_x::DX
dense_e::DE
bias::B
a::A
σ::F
negative_slope::T
channel::Pair{NTuple{2, Int}, Int}
heads::Int
concat::Bool
add_self_loops::Bool
end

@functor GraphormerLayer

Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a)
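# convenience constructor: a plain `in => out` pair is lifted to `(in, 0) => out`, i.e. no edge features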
GraphormerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphormerLayer((ch[1], 0) => ch[2], args...; kws...)

function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
init = glorot_uniform, bias::Bool = true, add_self_loops = true)
(in, ein), out = ch
if add_self_loops
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

dense_x = Dense(in, out * heads, bias = false)
dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing
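# bias length matches the layer output: out*heads when heads are concatenated, out when they are averaged;
# `a` holds one attention vector per head over [Wx_i; Wx_j] (plus W_e when edge features are used)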
b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false
a = init(ein > 0 ? 3out : 2out, heads)
negative_slope = convert(Float32, negative_slope)
GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops)
end

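# graph-level forward: read node/edge features from `g` and store the layer output back as node data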
(l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix,
e::Union{Nothing, AbstractMatrix} = nothing)
check_num_nodes(g, x)
@assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
@assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"

if l.add_self_loops
@assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."
g = add_self_loops(g)
end

_, chout = l.channel
heads = l.heads

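# project node features and split out the head dimension: (in, n) -> (out*heads, n) -> (out, heads, n)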
Wx = l.dense_x(x)
Wx = reshape(Wx, chout, heads, :)

# a hand-written message-passing step
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
Member

where is the message function defined?

Author

Apologies, I missed writing it. I added a comment to go back and add that; let me get on to this.
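For reference while that commit lands, here is a minimal sketch of what the missing `message` function could look like, following the GAT-style attention the constructor sets up (`a`, `negative_slope`, `dense_e`). This is a sketch under those assumptions, not the code from this PR:

```julia
# Hypothetical sketch of `message`, not part of this commit.
# Wxi, Wxj have shape (out, heads, num_edges); e has shape (ein, num_edges) or is `nothing`.
function message(l::GraphormerLayer, Wxi, Wxj, e)
    _, out = l.channel
    if e === nothing
        Wxx = vcat(Wxi, Wxj)                        # (2out, heads, num_edges)
    else
        We = reshape(l.dense_e(e), out, l.heads, :) # project edge features per head
        Wxx = vcat(Wxi, Wxj, We)                    # (3out, heads, num_edges)
    end
    logα = leakyrelu.(sum(l.a .* Wxx, dims = 1), l.negative_slope)  # (1, heads, num_edges)
    return (; logα, Wxj)  # the forward pass reads m.logα and m.Wxj
end
```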

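# normalize the attention logits over each node's incoming edges, weight the projected
# neighbor features, and sum them per node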
α = softmax_edge_neighbors(g, m.logα)
β = α .* m.Wxj
x = aggregate_neighbors(g, +, β)

if !l.concat
x = mean(x, dims = 2)
end
# flatten the head dimension, then apply bias and activation
x = reshape(x, :, size(x, 3))
x = l.σ.(x .+ l.bias)
return x
end