Skip to content

Commit 2b9221e

Browse files
committed
update: power_iter
1 parent ceef415 commit 2b9221e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,9 @@ def power_iter(mat_g: torch.Tensor, error_tolerance: float = 1e-6, num_iters: in
359359
"""
360360
v: torch.Tensor = 2.0 * torch.rand(list(mat_g.shape)[0], dtype=mat_g.dtype, device=mat_g.device) - 1
361361

362-
error: torch.Tensor = 1.0
362+
error: Union[torch.Tensor, float] = 1.0
363363
iters: int = 0
364-
singular_val: torch.Tensor = 0
364+
singular_val: Union[torch.Tensor, float] = 0.0
365365
while error > error_tolerance and iters < num_iters:
366366
v.div_(v.norm())
367367
mat_v = torch.mv(mat_g, v)
@@ -373,8 +373,6 @@ def power_iter(mat_g: torch.Tensor, error_tolerance: float = 1e-6, num_iters: in
373373
singular_val = s_v
374374
iters += 1
375375

376-
singular_val.div_(singular_val.norm())
377-
378376
return singular_val
379377

380378

0 commit comments

Comments
 (0)