Skip to content

Commit c8f80ae

Browse files
committed
update sm version guard
Signed-off-by: Hao Wu <[email protected]>
1 parent d6c89c9 commit c8f80ae

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

emerging_optimizers/orthogonalized_optimizers/muon_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,12 @@ def newton_schulz(
146146
X = X.to(torch.bfloat16)
147147
logging.log_first_n(logging.INFO, "Using BF16 I/O kernels for Newton-Schulz iteration.", 1)
148148
if use_syrk:
149-
ns_step_fn = newton_schulz_step_tsyrk
149+
sm_version = torch.cuda.get_device_capability()
150+
if sm_version in ((8, 0), (9, 0), (10, 0), (11, 0)):
151+
logging.log_first_n(
152+
logging.INFO, f"Using Triton SYRK kernels for Newton-Schulz iteration on SM {sm_version}.", 1
153+
)
154+
ns_step_fn = newton_schulz_step_tsyrk
150155

151156
for i in range(steps):
152157
a, b, c = coefficient_sets[i % len(coefficient_sets)]

emerging_optimizers/triton_kernels/syrk.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,10 +315,6 @@ def tsyrk_ex(
315315
Returns:
316316
Output tensor of shape (N, N)
317317
"""
318-
sm_version = torch.cuda.get_device_capability()
319-
assert sm_version in ((8, 0), (9, 0), (10, 0), (11, 0)), (
320-
f"Correctness of Triton kernel on SM {sm_version} can not be guaranteed."
321-
)
322318
assert a.dtype == torch.bfloat16, "Input tensor must be bfloat16"
323319
assert a.dim() == 2, "Input tensor must be 2D"
324320
assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous."

tests/test_muon_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from emerging_optimizers.orthogonalized_optimizers import muon, muon_utils
2323

2424

25+
_SM_VERSION = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
26+
27+
2528
def newton_schulz_ref(x: torch.Tensor, coefficient_sets: list[tuple[float, float, float]]) -> torch.Tensor:
2629
"""Reference Newton-Schulz iteration to compute the zeroth power / orthogonalization of x."""
2730
# Muon is not for 1d parameters
@@ -208,14 +211,11 @@ def test_qkv_split_shapes_validation(self):
208211
self.assertIn("tuple of 3 integers", str(cm.exception))
209212

210213

214+
@absltest.skipIf(
215+
_SM_VERSION is None or _SM_VERSION not in ((8, 0), (9, 0), (10, 0), (11, 0)),
216+
f"Correctness of Triton kernel on SM {_SM_VERSION} cannot be guaranteed.",
217+
)
211218
class TestNewtonSchulzStepWithTsyrk(parameterized.TestCase):
212-
def setUp(self):
213-
self.prev_precision = torch.get_float32_matmul_precision()
214-
torch.set_float32_matmul_precision("highest")
215-
216-
def tearDown(self):
217-
torch.set_float32_matmul_precision(self.prev_precision)
218-
219219
@parameterized.parameters(
220220
(32, 32),
221221
(32, 64),

0 commit comments

Comments
 (0)