Skip to content

Commit 24282a9

Browse files
committed
update: compute_power_schur_newton
1 parent 2b9221e commit 24282a9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -442,7 +442,7 @@ def compute_power_schur_newton(
442 442
"""
443 443
shape: List[int] = list(mat_g.shape)
444 444
if len(shape) == 1:
445 -
return torch.pow(mat_g + ridge_epsilon, -1 / p)
445 +
return torch.pow(mat_g + ridge_epsilon, -1.0 / p)
446 446

447 447
identity = torch.eye(shape[0], device=mat_g.device, dtype=torch.float32)
448 448
if shape[0] == 1:
@@ -458,10 +458,11 @@ def compute_power_schur_newton(
458 458
mat_m = mat_g * z
459 459

460 460
alpha: float = -1.0 / p
461 +
alpha_identity: torch.Tensor = (1.0 - alpha) * identity
461 462
error = torch.max(torch.abs(mat_m - identity))
462 463
count: int = 0
463 464
while error > error_tolerance and count < iter_count:
464 -
mat_m_i = (1 - alpha) * identity + alpha * mat_m
465 +
mat_m_i = alpha_identity + alpha * mat_m
465 466
new_mat_root = torch.matmul(mat_root, mat_m_i).float()
466 467
mat_m = torch.matmul(matrix_power(mat_m_i, p), mat_m).float()
467 468

0 commit comments

Comments
 (0)