Skip to content

Commit f29a9df

Browse files
committed
refactor: pre-conditioner
1 parent a770ec7 commit f29a9df

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(
270270
shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
271271
self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
272272
self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
273-
self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
273+
self.is_same_shapes = None not in shapes and len(torch.unique(shapes)) == 1
274274

275275
if self.is_same_shapes:
276276
self.statistics = torch.stack(self.statistics, dim=0)
@@ -302,8 +302,7 @@ def add_statistics(self, grad: torch.Tensor) -> None:
302302
reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
303303
partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)
304304

305-
for j in range(len(partitioned_grads)):
306-
partitioned_grad: torch.Tensor = partitioned_grads[j]
305+
for j, partitioned_grad in enumerate(partitioned_grads):
307306
for i in range(self.rank):
308307
axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
309308
stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
@@ -341,7 +340,7 @@ def precondition_block(
341340
We keep all axes in the same cyclic order they were originally.
342341
"""
343342
rank: int = len(partitioned_grad.shape)
344-
roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)
343+
roll: Tuple[int, ...] = (*range(1, rank), 0)
345344

346345
i: int = 0
347346
for should_precondition_dim in should_preconditioned_dims:
@@ -376,7 +375,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
376375

377376
merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)
378377

379-
return torch.reshape(merged_grad, self.original_shape)
378+
return merged_grad.reshape(self.original_shape)
380379

381380

382381
def build_graft(p: torch.Tensor, graft_type: int, diagonal_eps: float = 1e-10):

0 commit comments

Comments (0)