Skip to content

Commit a0af041

Browse files
committed
fix: add_statistics when PreConditionerType is 1 (INPUT)
1 parent abcf6f0 commit a0af041

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ def add_statistics(self, grad: torch.Tensor):
248248
partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)
249249

250250
w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)
251-
rank: int = len(self.transformed_shape)
251+
rank: int = sum(self.should_precondition_dims())
252252
for j, partitioned_grad in enumerate(partitioned_grads):
253253
for i in range(rank):
254-
axes: List[int] = list(range(i)) + list(range(i + 1, rank))
254+
axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
255255
stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, [axes, axes])
256256
self.statistics[j * rank + i].mul_(self.beta2).add_(stat, alpha=w2)
257257

0 commit comments

Comments
 (0)