diff --git a/torch_geometric/nn/conv/transformer_conv.py b/torch_geometric/nn/conv/transformer_conv.py index b3467774992f..9e046d604881 100644 --- a/torch_geometric/nn/conv/transformer_conv.py +++ b/torch_geometric/nn/conv/transformer_conv.py @@ -135,20 +135,24 @@ def __init__( self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False) else: self.lin_edge = self.register_parameter('lin_edge', None) - - if concat: - self.lin_skip = Linear(in_channels[1], heads * out_channels, - bias=bias) - if self.beta: - self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False) + if root_weight: + if concat: + self.lin_skip = Linear(in_channels[1], heads * out_channels, + bias=bias) + if self.beta: + self.lin_beta = Linear(3 * heads * out_channels, 1, + bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) else: - self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_skip = Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) else: - self.lin_skip = Linear(in_channels[1], out_channels, bias=bias) - if self.beta: - self.lin_beta = Linear(3 * out_channels, 1, bias=False) - else: - self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_skip = self.register_parameter('lin_skip', None) + self.lin_beta = self.register_parameter('lin_beta', None) self.reset_parameters() @@ -159,7 +163,8 @@ def reset_parameters(self): self.lin_value.reset_parameters() if self.edge_dim: self.lin_edge.reset_parameters() - self.lin_skip.reset_parameters() + if self.root_weight: + self.lin_skip.reset_parameters() if self.beta: self.lin_beta.reset_parameters()