Skip to content

[Bug] SumBatchLinearOperator fails for high-order tensors #100

@lmao14

Description

@lmao14

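🐛 Bug

Summing a lazily evaluated kernel matrix over batch dimension 0 with `LinearOperator.sum(0)` does not match the dense result. With batch shape `(4, 3)` it returns a result of the wrong shape (`torch.Size([4, 5, 5])` instead of `torch.Size([3, 2, 2])`). With batch shape `(5, 4, 3)` it raises a `RuntimeError`.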

To reproduce

**Code snippet to reproduce**

# Reproduction 1: ScaleKernel(RBFKernel) with a two-dimensional batch shape (4, 3).
import torch
import gpytorch
import linear_operator  # not referenced directly below; sum() dispatches into it

kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3]),),
                                        batch_shape=torch.Size([4, 3]))
# 2 inputs with 5 features each; the printed kernel shape is (4, 3, 2, 2).
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.shape)  # torch.Size([4, 3, 2, 2])
# Dense reference: summing batch dim 0 yields torch.Size([3, 2, 2]).
print(kxx.to_dense().sum(0).shape)
# Lazy path (LinearOperator.sum -> _sum_batch -> SumBatchLinearOperator) should
# match the dense reference above, but it prints torch.Size([4, 5, 5]).
print(kxx.sum(0).to_dense().shape)

torch.Size([4, 3, 2, 2])
torch.Size([3, 2, 2])
torch.Size([4, 5, 5])

# Reproduction 2: the same kernel with a three-dimensional batch shape (5, 4, 3).
# Here the lazy sum over batch dim 0 raises a RuntimeError from permute() inside
# BlockLinearOperator.__init__ (see the traceback below).
# NOTE(review): BlockLinearOperator.__init__ computes
# `base_linear_op.dim() + block_dim`, which presumes a negative block_dim, while
# LinearOperator._sum_batch passes a positive dim. This is the likely cause of
# both reproductions; confirm against the linear_operator source.
kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3]),),
                                        batch_shape=torch.Size([5, 4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.sum(0).to_dense().shape)  # raises RuntimeError instead of printing a shape

**Stack trace/error message**

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[65], line 5
      3 X = torch.randn([2, 5])
      4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:2517, in LinearOperator.sum(self, dim)
   2515 # Otherwise: it's a batch dimension
   2516 elif dim < self.dim():
-> 2517     return self._sum_batch(dim)
   2518 else:
   2519     raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:861, in LinearOperator._sum_batch(self, dim)
    850 """
    851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
    852 
   (...)
    857 :param dim: The (positive valued) dimension to sum
    858 """
    859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/block_linear_operator.py:50, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
     48 if block_dim != -3:
     49     positive_block_dim = base_linear_op.dim() + block_dim
---> 50     base_linear_op = base_linear_op._permute_batch(
     51         *range(positive_block_dim),
     52         *range(positive_block_dim + 1, base_linear_op.dim() - 2),
     53         positive_block_dim,
     54     )
     55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
     56 self.base_linear_op = base_linear_op

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:248, in LinearOperator._permute_batch(self, *dims)
    246 if torch.is_tensor(component):
    247     extra_dims = range(len(dims), component.dim())
--> 248     components.append(component.permute(*dims, *extra_dims))
    249 elif isinstance(component, LinearOperator):
    250     components.append(component._permute_batch(*dims))

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3

System information

Please complete the following information:

  • LinearOperator Version 0.5.3
  • PyTorch Version 2.0.1

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions