
Commit 373a545

made debug statements in logging

Signed-off-by: mikail <mkhona@nvidia.com>
1 parent: 4168b01


tests/test_spectral_clip_utils.py

Lines changed: 14 additions & 12 deletions
@@ -52,12 +52,14 @@ def test_spectral_clipping(self, dims, sigma_range):
         min_sv = singular_values.min().item()
         max_sv = singular_values.max().item()

-        logging.info(f"Original matrix shape: {x.shape}")
-        logging.info(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
-        logging.info(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]")
-        logging.info(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]")
-        logging.info(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}")
-
+        logging.debug(f"Original matrix shape: {x.shape}")
+        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
+        logging.debug(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]")
+        logging.debug(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]")
+        logging.debug(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}")
+
+        # use higher tolerance for lower singular values
+        # typically, this algorithm introduces more error for lower singular values
         tolerance_upper = 1e-1
         tolerance_lower = 5e-1
         self.assertGreaterEqual(
@@ -82,8 +84,8 @@ def test_spectral_hardcap(self, dims, beta):
         U_orig, original_singular_values, Vt_orig = torch.linalg.svd(x, full_matrices=False)
         original_min_sv = original_singular_values.min().item()
         original_max_sv = original_singular_values.max().item()
-        logging.info(f"Original matrix shape: {x.shape}")
-        logging.info(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
+        logging.debug(f"Original matrix shape: {x.shape}")
+        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")

         hardcapped_x = orthogonalized_optimizers.spectral_hardcap(x, beta=beta)

@@ -93,9 +95,9 @@ def test_spectral_hardcap(self, dims, beta):

         max_sv = singular_values.max().item()

-        logging.info(f"Hardcapped max singular value: {max_sv:.6f}")
-        logging.info(f"Beta (upper bound): {beta:.6f}")
-        logging.info(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}")
+        logging.debug(f"Hardcapped max singular value: {max_sv:.6f}")
+        logging.debug(f"Beta (upper bound): {beta:.6f}")
+        logging.debug(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}")

         self.assertLessEqual(
             max_sv - tolerance_upper,
@@ -112,7 +114,7 @@ def test_spectral_hardcap(self, dims, beta):
         relative_polar_frobenius_diff = torch.norm(polar_orig - polar_hard, "fro") / torch.norm(polar_orig, "fro")
         polar_tolerance = 1e-4

-        logging.info(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}")
+        logging.debug(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}")

         self.assertLessEqual(
             relative_polar_frobenius_diff,
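
Because these messages are now emitted at DEBUG level, they stay hidden under the `logging` module's default WARNING threshold. Below is a minimal sketch, not part of this commit, of one way to surface them when running this test module on its own; it assumes the tests rely on the root logger (as the `logging.debug(...)` calls above do) and are plain `unittest` cases, and the discovery path mirrors the file shown in the diff.

# Minimal sketch (not part of the commit): surface the DEBUG-level messages
# when running the spectral-clip tests directly.
import logging
import unittest

# The root logger defaults to WARNING, so the logging.debug(...) calls in the
# tests would otherwise produce no output.
logging.basicConfig(level=logging.DEBUG)

# Discover and run just this test module; the directory and pattern are
# assumptions about the repository layout.
suite = unittest.defaultTestLoader.discover("tests", pattern="test_spectral_clip_utils.py")
unittest.TextTestRunner(verbosity=2).run(suite)

If the suite is run through pytest instead, something like `pytest tests/test_spectral_clip_utils.py -o log_cli=true --log-cli-level=DEBUG` achieves the same effect.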
