diff --git a/torch_geometric_temporal/nn/attention/tsagcn.py b/torch_geometric_temporal/nn/attention/tsagcn.py index aa902e23..8faf018a 100644 --- a/torch_geometric_temporal/nn/attention/tsagcn.py +++ b/torch_geometric_temporal/nn/attention/tsagcn.py @@ -237,7 +237,7 @@ def _adaptive_forward(self, x, y): .contiguous() .view(N, V, self.inter_c * T) ) - A2 = self.conv_b[i](x).view(N, self.inter_c * T, V) + A2 = self.conv_b[i](x).reshape(N, self.inter_c * T, V) A1 = self.tan(torch.matmul(A1, A2) / A1.size(-1)) # N V V A1 = A[i] + A1 * self.alpha A2 = x.view(N, C * T, V)