Skip to content

Commit 37788c2

Browse files
dshi7 authored and meta-codesync[bot] committed
Recover pre-commit errors (#709)
Summary: The errors were introduced by [D87842249](https://www.internalfb.com/diff/D87842249). Need to check why the diff can bypass the errors. Or is it possible to set up pre-commit runs in OSS periodically? This sounds like a perfect task for AI agents. Pull Request resolved: #709 Reviewed By: agron911 Differential Revision: D88117942 Pulled By: dshi7 fbshipit-source-id: 4753087e45bbaa500298358911599d59acb6d56a
1 parent a21d269 commit 37788c2

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

python/test/unit/language/test_tlx.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,7 +2311,7 @@ def test_stoch_round_partial_pack(dst_dtype, device):
23112311

23122312
@triton.jit
23132313
def stoch_round_partial_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_ROUNDED: tl.constexpr,
2314-
QUARTER_SIZE_ROUNDED: tl.constexpr):
2314+
QUARTER_SIZE_ROUNDED: tl.constexpr):
23152315
# Use power-of-2 size for arange (triton requirement), then mask to actual size
23162316
offsets_full = tl.arange(0, BLOCK_SIZE_ROUNDED)
23172317
mask = offsets_full < BLOCK_SIZE
@@ -2357,8 +2357,8 @@ def stoch_round_partial_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr, BLOCK_SIZ
23572357

23582358

23592359
@pytest.mark.skipif(not is_blackwell(), reason="Need Blackwell")
2360-
@pytest.mark.parametrize("invalid_src, invalid_dst",
2361-
[("float16", "float8_e5m2"), ("bfloat16", "float16"), ("float32", "int32")])
2360+
@pytest.mark.parametrize("invalid_src, invalid_dst", [("float16", "float8_e5m2"), ("bfloat16", "float16"),
2361+
("float32", "int32")])
23622362
def test_stoch_round_invalid_dtypes(invalid_src, invalid_dst, device):
23632363
"""Test that invalid dtype combinations raise proper errors."""
23642364

@@ -2423,6 +2423,5 @@ def stoch_round_seed_kernel(x_ptr, y_ptr, seed, BLOCK_SIZE: tl.constexpr):
24232423

24242424
# Results should be different for at least some values
24252425
different_count = (b1.float() != b2.float()).sum().item()
2426-
assert different_count > SIZE * 0.1, (
2427-
f"Different seeds should produce different results, "
2428-
f"but only {different_count}/{SIZE} values differ")
2426+
assert different_count > SIZE * 0.1, (f"Different seeds should produce different results, "
2427+
f"but only {different_count}/{SIZE} values differ")

third_party/tlx/language/tlx/utility.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,16 @@ def stoch_round(
111111
Tensor with dtype dst_ty and shape matching src.
112112
"""
113113
capability = int(cuda_parse_arch(_semantic.builder.options.arch))
114-
assert capability >= 100, (
115-
f"stoch_round requires compute capability >= 100 (Blackwell GPU), "
116-
f"current capability: {capability}"
117-
)
114+
assert capability >= 100, (f"stoch_round requires compute capability >= 100 (Blackwell GPU), "
115+
f"current capability: {capability}")
118116
src_ty = src.type
119117
src_sca_ty = src_ty.scalar
120118

121-
assert src_sca_ty == tl.float32, (
122-
f"Stochastic rounding only supports fp32 source, got {src_sca_ty}. "
123-
f"Source must be float32."
124-
)
125-
assert dst_ty in [tl.float8e5, tl.float8e4nv, tl.float16, tl.bfloat16], (
126-
f"Stochastic rounding only supports fp8/fp16/bf16 destination, got {dst_ty}. "
127-
f"Supported types: float8e5 (fp8 E5M2), float8e4nv (fp8 E4M3FN), float16, bfloat16"
128-
)
119+
assert src_sca_ty == tl.float32, (f"Stochastic rounding only supports fp32 source, got {src_sca_ty}. "
120+
f"Source must be float32.")
121+
assert dst_ty in [tl.float8e5, tl.float8e4nv, tl.float16, tl.bfloat16
122+
], (f"Stochastic rounding only supports fp8/fp16/bf16 destination, got {dst_ty}. "
123+
f"Supported types: float8e5 (fp8 E5M2), float8e4nv (fp8 E4M3FN), float16, bfloat16")
129124

130125
# Verify rbits shape matches src shape
131126
rbits_ty = rand_bits.type

0 commit comments

Comments (0)