Skip to content

Commit 835ab1d

Browse files
committed
replaced max with clamp
1 parent 0788dc9 commit 835ab1d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float =
4343
# Note: this function is written in fp32 to avoid numerical instability while computing the taylor expansion of the exponential map
4444
with utils.fp32_matmul_precision("highest"):
4545
R = Q.T - Q
46-
R /= torch.max(norm_lower_bound_skew(R), eps)
46+
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
4747
RQ = R @ Q
4848
# trace of RQ is always positive,
4949
# since tr(RQ) = ⟨R, Q⟩_F = ⟨Q^T - Q, Q⟩_F = ||Q||_F^2 - ⟨Q, Q⟩_F = ||Q||_F^2 - tr(Q^T Q) ≥ 0

emerging_optimizers/psgd/psgd_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps:
8585
"""
8686

8787
# Compute scaling factor from the largest diagonal entry to prevent overflow/underflow
88-
scale = torch.max(A.diagonal().amax(), eps)
88+
scale = torch.clamp(A.diagonal().amax(), min=eps)
8989
A = A / scale
9090

9191
bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)
@@ -113,7 +113,7 @@ def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps
113113
"""
114114

115115
# Compute scaling factor from the max absolute value to prevent overflow/underflow
116-
scale = torch.max(A.abs().amax(), eps)
116+
scale = torch.clamp(A.abs().amax(), min=eps)
117117
A = A / scale
118118

119119
bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

0 commit comments

Comments
 (0)