From 3fcb2fa3265858de951cec985fa4a25f79c5651e Mon Sep 17 00:00:00 2001 From: Mateus Israel Silva <38586542+Mateusrael@users.noreply.github.com> Date: Wed, 26 Nov 2025 10:35:46 -0300 Subject: [PATCH 1/2] Fix layer creation of lin_skip on transformer_conv.py Doesn't create lin_skip when root_weight is False --- torch_geometric/nn/conv/transformer_conv.py | 30 ++++++++++++--------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/torch_geometric/nn/conv/transformer_conv.py b/torch_geometric/nn/conv/transformer_conv.py index b3467774992f..06a4e9b0b5c3 100644 --- a/torch_geometric/nn/conv/transformer_conv.py +++ b/torch_geometric/nn/conv/transformer_conv.py @@ -135,20 +135,23 @@ def __init__( self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False) else: self.lin_edge = self.register_parameter('lin_edge', None) - - if concat: - self.lin_skip = Linear(in_channels[1], heads * out_channels, - bias=bias) - if self.beta: - self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False) + if root_weight: + if concat: + self.lin_skip = Linear(in_channels[1], heads * out_channels, + bias=bias) + if self.beta: + self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) else: - self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_skip = Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) else: - self.lin_skip = Linear(in_channels[1], out_channels, bias=bias) - if self.beta: - self.lin_beta = Linear(3 * out_channels, 1, bias=False) - else: - self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_skip = self.register_parameter('lin_skip', None) + self.lin_beta = self.register_parameter('lin_beta', None) self.reset_parameters() @@ -159,7 +162,8 @@ def reset_parameters(self): self.lin_value.reset_parameters() if self.edge_dim: self.lin_edge.reset_parameters() - self.lin_skip.reset_parameters() + if self.root_weight: + self.lin_skip.reset_parameters() if self.beta: self.lin_beta.reset_parameters() From e69943160f83295fe781893bddc2e1a535952b52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:41:28 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/conv/transformer_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/transformer_conv.py b/torch_geometric/nn/conv/transformer_conv.py index 06a4e9b0b5c3..9e046d604881 100644 --- a/torch_geometric/nn/conv/transformer_conv.py +++ b/torch_geometric/nn/conv/transformer_conv.py @@ -140,7 +140,8 @@ def __init__( self.lin_skip = Linear(in_channels[1], heads * out_channels, bias=bias) if self.beta: - self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False) + self.lin_beta = Linear(3 * heads * out_channels, 1, + bias=False) else: self.lin_beta = self.register_parameter('lin_beta', None) else: