We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f29a9df commit 1d9dfb0Copy full SHA for 1d9dfb0
pytorch_optimizer/optimizer/shampoo_utils.py
@@ -270,7 +270,7 @@ def __init__(
270
shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
271
self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
272
self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
273
- self.is_same_shapes = None not in shapes and len(torch.unique(shapes)) == 1
+ self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
274
275
if self.is_same_shapes:
276
self.statistics = torch.stack(self.statistics, dim=0)
0 commit comments