From 3a853590aec0ec19fd5668dc77574f2fbb0b26a6 Mon Sep 17 00:00:00 2001 From: S Date: Mon, 13 Mar 2023 16:22:42 +0530 Subject: [PATCH 01/10] added Graphormer.jl --- src/Graphormer.jl | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/Graphormer.jl diff --git a/src/Graphormer.jl b/src/Graphormer.jl new file mode 100644 index 000000000..bcb9c90ae --- /dev/null +++ b/src/Graphormer.jl @@ -0,0 +1,38 @@ +using Flux +using Flux: @epochs, mse, params + +function Graphormer(input_dim, hidden_dim, output_dim, num_layers, num_heads) + + # Define the transformer encoder block + function TransformerEncoder(hidden_dim, num_heads) + multi_head_attention = Chain([Dense(hidden_dim, hidden_dim) for i in 1:num_heads]...) + layer_norm1 = LayerNorm(hidden_dim) + position_wise_feed_forward = Chain(Dense(hidden_dim, hidden_dim, relu), Dense(hidden_dim, hidden_dim)) + layer_norm2 = LayerNorm(hidden_dim) + + function (x) + # Calculate multi-head attention + heads = [head(x) for head in multi_head_attention] + concatenated = Flux.cat(heads..., dims=3) + attention_out = Flux.squeeze(sum(concatenated .* x, dims=2), dims=2) + attention_out = layer_norm1(x + attention_out) + + # Calculate position-wise feed forward network + ff_out = position_wise_feed_forward(attention_out) + ff_out = layer_norm2(attention_out + ff_out) + return ff_out + end + end + + input_embedding = Dense(input_dim, hidden_dim) + transformer_layers = Chain([TransformerEncoder(hidden_dim, num_heads) for i in 1:num_layers]...) + output_layer = Dense(hidden_dim, output_dim) + + function (x) + x = input_embedding(x) + x = transformer_layers(x) + x = mean(x, dims=1) + x = output_layer(x) + return x + end +end From a48e21028fd6b0d84b291624969d5bff6109e7ca Mon Sep 17 00:00:00 2001 From: S Date: Fri, 17 Mar 2023 12:14:44 +0530 Subject: [PATCH 02/10] used struct for graphomer --- src/Graphormer.jl | 38 --------------------------- src/layers/conv.jl | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 38 deletions(-) delete mode 100644 src/Graphormer.jl diff --git a/src/Graphormer.jl b/src/Graphormer.jl deleted file mode 100644 index bcb9c90ae..000000000 --- a/src/Graphormer.jl +++ /dev/null @@ -1,38 +0,0 @@ -using Flux -using Flux: @epochs, mse, params - -function Graphormer(input_dim, hidden_dim, output_dim, num_layers, num_heads) - - # Define the transformer encoder block - function TransformerEncoder(hidden_dim, num_heads) - multi_head_attention = Chain([Dense(hidden_dim, hidden_dim) for i in 1:num_heads]...) - layer_norm1 = LayerNorm(hidden_dim) - position_wise_feed_forward = Chain(Dense(hidden_dim, hidden_dim, relu), Dense(hidden_dim, hidden_dim)) - layer_norm2 = LayerNorm(hidden_dim) - - function (x) - # Calculate multi-head attention - heads = [head(x) for head in multi_head_attention] - concatenated = Flux.cat(heads..., dims=3) - attention_out = Flux.squeeze(sum(concatenated .* x, dims=2), dims=2) - attention_out = layer_norm1(x + attention_out) - - # Calculate position-wise feed forward network - ff_out = position_wise_feed_forward(attention_out) - ff_out = layer_norm2(attention_out + ff_out) - return ff_out - end - end - - input_embedding = Dense(input_dim, hidden_dim) - transformer_layers = Chain([TransformerEncoder(hidden_dim, num_heads) for i in 1:num_layers]...) 
- output_layer = Dense(hidden_dim, output_dim) - - function (x) - x = input_embedding(x) - x = transformer_layers(x) - x = mean(x, dims=1) - x = output_layer(x) - return x - end -end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index dc9dd79a1..82574a409 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1670,3 +1670,67 @@ function Base.show(io::IO, l::TransformerConv) (in, ein), out = l.channels print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))") end + +struct GraphomerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer + dense_x::DX + dense_e::DE + bias::B + a::A + σ::F + negative_slope::T + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool +end + +@functor GraphomerLayer + +Flux.trainable(l::GraphomerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) + +GraphomerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphomerLayer((ch[1], 0) => ch[2], args...; kws...) + +function GraphomerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; + heads::Int = 1, concat::Bool = true, negative_slope = 0.2, + init = glorot_uniform, bias::Bool = true, add_self_loops = true) + (in, ein), out = ch + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." + end + + dense_x = Dense(in, out * heads, bias = false) + dense_e = ein > 0 ? Dense(ein, out * heads, bias = false) : nothing + b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false + a = init(ein > 0 ? 3out : 2out, heads) + negative_slope = convert(Float32, negative_slope) + GraphomerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops) +end + +(l::GraphomerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + +function (l::GraphomerLayer)(g::GNNGraph, x::AbstractMatrix, + e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" + @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + + if l.add_self_loops + @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." + g = add_self_loops(g) + end + + _, chout = l.channel + heads = l.heads + + Wx = l.dense_x(x) + Wx = reshape(Wx, chout, heads, :) + + # a hand-written message passing + m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e) + α = softmax_edge_neighbors(g, m.logα) + β = α .* m.Wxj + x = aggregate_neighbors(g, +, β) + + if !l.concat + x = mean(x, dims = 2) + end From 89b71934005377d13df009c291f2ccd855cc32e5 Mon Sep 17 00:00:00 2001 From: S Date: Sun, 19 Mar 2023 08:57:40 +0530 Subject: [PATCH 03/10] added docstring --- src/layers/conv.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 82574a409..5ec758c48 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1671,6 +1671,22 @@ function Base.show(io::IO, l::TransformerConv) print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))") end +@doc raw""" +GraphomerLayer constructor. + +Parameters: +- `ch`: A `Pair` object representing the input and output channels of the layer. 
The input channel should be a tuple of the form `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features. The output channel should be an integer representing the number of output features for each node. +- `σ`: The activation function to apply to the node features after the linear transformation. Defaults to `identity`. +- `heads`: The number of attention heads to use. Defaults to 1. +- `concat`: Whether to concatenate the output of each head or average them. Defaults to `true`. +- `negative_slope`: The slope of the negative part of the LeakyReLU activation function. Defaults to 0.2. +- `init`: The initialization function to use for the attention weights. Defaults to `glorot_uniform`. +- `bias`: Whether to include a bias term in the linear transformation. Defaults to `true`. +- `add_self_loops`: Whether to add self-loops to the graph. Defaults to `true`. + +Example: +layer = GraphomerLayer((64, 32) => 128, σ = relu, heads = 4, concat = true, negative_slope = 0.1, init = xavier_uniform, bias = true, add_self_loops = false) +""" struct GraphomerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer dense_x::DX dense_e::DE @@ -1734,3 +1750,4 @@ function (l::GraphomerLayer)(g::GNNGraph, x::AbstractMatrix, if !l.concat x = mean(x, dims = 2) end +end From 53efa7d9522e4b38a5262f2138916f230d5500cf Mon Sep 17 00:00:00 2001 From: S Date: Sun, 19 Mar 2023 09:04:53 +0530 Subject: [PATCH 04/10] changed struct to GraphormerLayer --- src/layers/conv.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 5ec758c48..2867557ec 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1672,7 +1672,7 @@ function Base.show(io::IO, l::TransformerConv) end @doc raw""" -GraphomerLayer constructor. +GraphormerLayer constructor. Parameters: - `ch`: A `Pair` object representing the input and output channels of the layer. The input channel should be a tuple of the form `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features. The output channel should be an integer representing the number of output features for each node. @@ -1685,9 +1685,9 @@ Parameters: - `add_self_loops`: Whether to add self-loops to the graph. Defaults to `true`. Example: -layer = GraphomerLayer((64, 32) => 128, σ = relu, heads = 4, concat = true, negative_slope = 0.1, init = xavier_uniform, bias = true, add_self_loops = false) +layer = GraphormerLayer((64, 32) => 128, σ = relu, heads = 4, concat = true, negative_slope = 0.1, init = xavier_uniform, bias = true, add_self_loops = false) """ -struct GraphomerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer +struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer dense_x::DX dense_e::DE bias::B @@ -1700,13 +1700,12 @@ struct GraphomerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: Abstract add_self_loops::Bool end -@functor GraphomerLayer +@functor GraphormerLayer -Flux.trainable(l::GraphomerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) +Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) +GraphormerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphormerLayer((ch[1], 0) => ch[2], args...; kws...) -GraphomerLayer(ch::Pair{Int, Int}, args...; kws...) 
= GraphomerLayer((ch[1], 0) => ch[2], args...; kws...) - -function GraphomerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; +function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; heads::Int = 1, concat::Bool = true, negative_slope = 0.2, init = glorot_uniform, bias::Bool = true, add_self_loops = true) (in, ein), out = ch @@ -1719,12 +1718,12 @@ function GraphomerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false a = init(ein > 0 ? 3out : 2out, heads) negative_slope = convert(Float32, negative_slope) - GraphomerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops) + GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops) end -(l::GraphomerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) -function (l::GraphomerLayer)(g::GNNGraph, x::AbstractMatrix, +function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" From 6aa987d0dc35c2c84b2d8c3f0f73a87d4883c031 Mon Sep 17 00:00:00 2001 From: S Date: Wed, 22 Mar 2023 10:38:00 +0530 Subject: [PATCH 05/10] added tests --- test/layers/conv.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 72f413725..8ead70217 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -334,3 +334,46 @@ end end end + +@testset "Graphormer" begin + @testset "Initialization" begin + n_layers = 3 + d_model = 64 + n_heads = 8 + d_ff = 256 + dropout = 0.1 + g = Graphormer(n_layers, d_model, n_heads, d_ff, dropout) + @test typeof(g) == Graphormer + @test length(g.layers) == n_layers + end + + @testset "Forward pass" begin + n_nodes = 16 + n_edges = 32 + n_feats = 32 + n_classes = 10 + g = rand_graph(n_nodes, n_edges) + x = randn(Float32, n_feats, n_nodes) + y = rand(1:n_classes, n_nodes) + model = Graphormer(3, 64, 8, 256, 0.1) + out = model(g, x) + @test size(out) == (n_classes, n_nodes) + end + + @testset "Backward pass" begin + n_nodes = 16 + n_edges = 32 + n_feats = 32 + n_classes = 10 + g = rand_graph(n_nodes, n_edges) + x = randn(Float32, n_feats, n_nodes) + y = rand(1:n_classes, n_nodes) + model = Graphormer(3, 64, 8, 256, 0.1) + out = model(g, x) + loss = Flux.crossentropy(out, y) + Flux.back!(loss) + for p in params(model) + @test size(grad(p)) == size(p) + end + end +end From 99813a6377e070574a4f16da7c8551ebb1d507bf Mon Sep 17 00:00:00 2001 From: S Date: Thu, 23 Mar 2023 01:07:10 +0530 Subject: [PATCH 06/10] Update src/layers/conv.jl Co-authored-by: Carlo Lucibello --- src/layers/conv.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ea762a5c3..5b80a5137 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1721,6 +1721,7 @@ end @functor GraphormerLayer Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) + GraphormerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphormerLayer((ch[1], 0) => ch[2], args...; kws...) 
function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; From fac37ca1c621d997dab557a12195961809acf0cb Mon Sep 17 00:00:00 2001 From: S Date: Tue, 28 Mar 2023 05:54:42 +0530 Subject: [PATCH 07/10] updated eq 5 --- src/layers/conv.jl | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 5b80a5137..e3c68e870 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1716,14 +1716,13 @@ struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: Abstrac heads::Int concat::Bool add_self_loops::Bool + phi::Function end @functor GraphormerLayer Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) -GraphormerLayer(ch::Pair{Int, Int}, args...; kws...) = GraphormerLayer((ch[1], 0) => ch[2], args...; kws...) - function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; heads::Int = 1, concat::Bool = true, negative_slope = 0.2, init = glorot_uniform, bias::Bool = true, add_self_loops = true) @@ -1737,7 +1736,7 @@ function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false a = init(ein > 0 ? 3out : 2out, heads) negative_slope = convert(Float32, negative_slope) - GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops) + GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, phi) end (l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) @@ -1755,17 +1754,15 @@ function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, _, chout = l.channel heads = l.heads - Wx = l.dense_x(x) Wx = reshape(Wx, chout, heads, :) - - # a hand-written message passing m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e) α = softmax_edge_neighbors(g, m.logα) β = α .* m.Wxj x = aggregate_neighbors(g, +, β) - if !l.concat x = mean(x, dims = 2) + else + x = reshape(x, size(x, 1), heads * end -end +end \ No newline at end of file From ba9333cabdf80b4f4e9b8dd7fee268cdb66adeb6 Mon Sep 17 00:00:00 2001 From: S Date: Tue, 28 Mar 2023 06:01:23 +0530 Subject: [PATCH 08/10] added message function --- src/layers/conv.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e3c68e870..e6d8163d4 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1765,4 +1765,15 @@ function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, else x = reshape(x, size(x, 1), heads * end +end +function message(l::GraphormerLayer, xi, xj, e) + θ = cat(xi, xj, dims=2) + if l.dense_e !== nothing + fe = l.dense_e(e) + fe = reshape(fe, size(fe, 1), l.heads, :) + θ = cat(θ, fe, dims=1) + end + W = l.a * θ + W = l.σ.(W) + return W end \ No newline at end of file From 738775229b7b94cea80b679805f7134abf24c5d9 Mon Sep 17 00:00:00 2001 From: S Date: Fri, 7 Apr 2023 01:20:17 +0530 Subject: [PATCH 09/10] updates on eq7 --- src/layers/conv.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e6d8163d4..ed7996931 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1756,15 +1756,17 @@ function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, heads = l.heads Wx = l.dense_x(x) Wx = reshape(Wx, chout, heads, :) - m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e) - α = softmax_edge_neighbors(g, m.logα) - β = α .* m.Wxj - x = aggregate_neighbors(g, 
+, β) - if !l.concat - x = mean(x, dims = 2) - else - x = reshape(x, size(x, 1), heads * + h = Wx + for i in 1:l.phi + h = l.MHA(l.LN(h)) + h = h + Wx end + h = reshape(h, size(h, 1), heads * chout) + if l.bias !== false + h = l.bias(h) + end + h = l.σ.(h) + return h end function message(l::GraphormerLayer, xi, xj, e) θ = cat(xi, xj, dims=2) From b8b0af72380d30167c5b1e9298f78b896fd05ed3 Mon Sep 17 00:00:00 2001 From: S Date: Sun, 25 Jun 2023 20:57:58 +0530 Subject: [PATCH 10/10] added tests+improved on the code --- src/layers/conv.jl | 76 ++++++++++++++++++++++++++++----------------- test/layers/conv.jl | 38 +++++++++++++++++++++++ 2 files changed, 86 insertions(+), 28 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 27b0e1d7e..f16e691e5 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1691,22 +1691,43 @@ function Base.show(io::IO, l::TransformerConv) end @doc raw""" -GraphormerLayer constructor. - -Parameters: -- `ch`: A `Pair` object representing the input and output channels of the layer. The input channel should be a tuple of the form `(in_channels, num_edge_features)`, where `in_channels` is the number of input node features and `num_edge_features` is the number of input edge features. The output channel should be an integer representing the number of output features for each node. -- `σ`: The activation function to apply to the node features after the linear transformation. Defaults to `identity`. -- `heads`: The number of attention heads to use. Defaults to 1. -- `concat`: Whether to concatenate the output of each head or average them. Defaults to `true`. -- `negative_slope`: The slope of the negative part of the LeakyReLU activation function. Defaults to 0.2. -- `init`: The initialization function to use for the attention weights. Defaults to `glorot_uniform`. -- `bias`: Whether to include a bias term in the linear transformation. Defaults to `true`. -- `add_self_loops`: Whether to add self-loops to the graph. Defaults to `true`. - -Example: -layer = GraphormerLayer((64, 32) => 128, σ = relu, heads = 4, concat = true, negative_slope = 0.1, init = xavier_uniform, bias = true, add_self_loops = false) + GraphormerLayer((in, ein) => out; [σ, heads, concat, negative_slope, init, add_self_loops, bias, phi]) + +A layer implementing the Graphormer model from the paper +["Do Transformers Really Perform Bad for Graph Representation?"](https://arxiv.org/abs/2106.05234), +which combines transformers and graph neural networks. + +It applies a multi-head attention mechanism to node features and optionally edge features, +along with layer normalization and an activation function, +which makes it possible to capture both intra-node and inter-node dependencies. + +The layer's forward pass is given by: +```math +h^{(0)} = Wx, +h^{(l+1)} = h^{(l)} + \mathrm{MHA}(\mathrm{LN}(h^{(l)})), +where W is a learnable weight matrix, x are the node features, +LN is layer normalization, and MHA is multi-head attention. + +Arguments + +in: Dimension of input node features. +ein: Dimension of the edge features; if 0, no edge features will be used. +out: Dimension of the output. +σ: Activation function to apply to the node features after the linear transformation. +Default identity. +heads: Number of heads in output. Default 1. +concat: Concatenate layer output or not. If not, layer output is averaged +over the heads. Default true. +negative_slope: The slope of the negative part of the LeakyReLU activation function. +Default 0.2. +init: Weight matrices' initializing function. 
Default glorot_uniform. +add_self_loops: Add self loops to the input graph. Default true. +bias: If set, the layer will also learn an additive bias for the nodes. +Default true. +phi: Number of Phi functions (Layer Normalization and Multi-Head Attention) +to be applied. Default 2. """ -struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B} <: GNNLayer +struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, F, B, LN, MHA} <: GNNLayer dense_x::DX dense_e::DE bias::B @@ -1717,16 +1738,14 @@ struct GraphormerLayer{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: Abstrac heads::Int concat::Bool add_self_loops::Bool - phi::Function + phi::Int + LN::LN + MHA::MHA end -@functor GraphormerLayer - -Flux.trainable(l::GraphormerLayer) = (l.dense_x, l.dense_e, l.bias, l.a) - function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; heads::Int = 1, concat::Bool = true, negative_slope = 0.2, - init = glorot_uniform, bias::Bool = true, add_self_loops = true) + init = glorot_uniform, bias::Bool = true, add_self_loops = true, phi::Int = 2) (in, ein), out = ch if add_self_loops @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." @@ -1737,13 +1756,12 @@ function GraphormerLayer(ch::Pair{NTuple{2, Int}, Int}, σ = identity; b = bias ? Flux.create_bias(dense_x.weight, true, concat ? out * heads : out) : false a = init(ein > 0 ? 3out : 2out, heads) negative_slope = convert(Float32, negative_slope) - GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, phi) + LN = LayerNorm(out) + MHA = MultiheadAttention(out, heads) + GraphormerLayer(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, phi, LN, MHA) end -(l::GraphormerLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) - -function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, - e::Union{Nothing, AbstractMatrix} = nothing) +function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" @@ -1769,6 +1787,7 @@ function (l::GraphormerLayer)(g::GNNGraph, x::AbstractMatrix, h = l.σ.(h) return h end + function message(l::GraphormerLayer, xi, xj, e) θ = cat(xi, xj, dims=2) if l.dense_e !== nothing @@ -1776,7 +1795,8 @@ function message(l::GraphormerLayer, xi, xj, e) fe = reshape(fe, size(fe, 1), l.heads, :) θ = cat(θ, fe, dims=1) end - W = l.a * θ + α = softmax(l.a' * θ) ## Not sure if I can use it from flux directly + W = α .* θ W = l.σ.(W) return W -end \ No newline at end of file +end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index f4f2a9282..5631b1383 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -333,3 +333,41 @@ end outsize = (in_channel, g.num_nodes)) end end + +@testset "GraphormerLayer" begin + heads = 2 + concat = true + negative_slope = 0.2 + add_self_loops = true + phi = 2 + ch = (in_channel, 0) => out_channel + σ = relu + init = glorot_uniform + bias = true + l = GraphormerLayer(ch, σ; heads = heads, concat = concat, + negative_slope = negative_slope, init = init, + bias = bias, add_self_loops = add_self_loops, phi = phi) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, + outsize = (concat ? 
heads * out_channel : out_channel, g.num_nodes)) + end + l = GraphormerLayer(ch, σ; heads = heads, concat = concat, + negative_slope = negative_slope, init = init, + bias = bias, add_self_loops = false, phi = phi) + test_layer(l, g1, rtol = RTOL_HIGH, + outsize = (concat ? heads * out_channel : out_channel, g1.num_nodes)) + ein = 3 + ch = (in_channel, ein) => out_channel + l = GraphormerLayer(ch, σ; heads = heads, concat = concat, + negative_slope = negative_slope, init = init, + bias = bias, add_self_loops = add_self_loops, phi = phi) + g = GNNGraph(g1, edata = rand(T, ein, g1.num_edges)) + test_layer(l, g, rtol = RTOL_HIGH, + outsize = (concat ? heads * out_channel : out_channel, g.num_nodes)) + l = GraphormerLayer(ch, σ; heads = heads, concat = concat, + negative_slope = negative_slope, init = init, + bias = false, add_self_loops = add_self_loops, phi = phi) + test_layer(l, g, rtol = RTOL_HIGH, + outsize = (concat ? heads * out_channel : out_channel, g.num_nodes)) + @test length(Flux.params(l)) == (ein > 0 ? 7 : 6) +end
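
Note (not part of the patches above): a minimal usage sketch of the layer this series introduces, based on the constructor signature in the PATCH 10 docstring and the expected output sizes in its tests. It assumes GraphormerLayer, GNNGraph and rand_graph are exported by the package once the PR is merged; the dimensions are illustrative only.

    using Flux
    using GraphNeuralNetworks

    in_dim, out_dim, heads = 16, 32, 4

    # (in_dim, 0) => out_dim: 16 input node features, no edge features,
    # out_dim output features per attention head (assumed API from PATCH 10)
    layer = GraphormerLayer((in_dim, 0) => out_dim, relu;
                            heads = heads, concat = true, phi = 2)

    g = rand_graph(10, 40)                   # 10 nodes, 40 edges
    x = randn(Float32, in_dim, g.num_nodes)  # node feature matrix

    # Per the PATCH 10 tests, the output is expected to have size
    # (heads * out_dim, g.num_nodes) when concat = true,
    # and (out_dim, g.num_nodes) when concat = false.
    h = layer(g, x)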