
Commit 51722e6

[KERNELS] Fix bf16 x mxfp4 when EVEN_K is False (#7203)
1 parent 4d3c498 commit 51722e6

File tree

2 files changed: +6, -2 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 4 additions & 0 deletions
@@ -193,6 +193,7 @@ class Case:
             Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
             Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
             Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
+            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
             Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
             Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
             Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
@@ -243,6 +244,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("float16 x mx not supported with cuda capability >= 10")
         if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
             pytest.skip("float8 x mx not supported with cuda capability < 10")
+        if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
+            pytest.skip("Not enough memory on A100")
+
     elif is_hip():
         if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
             pytest.skip("float8 x mx only supported on CDNA4")

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -233,9 +233,9 @@ def _matmul_ogs(
         mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
     else:
         mask_k = offs_k < k
-        mask_k_w = offs_w_k < (tl.cdiv(k, W_K_DIVISOR) * W_K_MULTIPLIER)
+        mask_k_w = offs_w_k < ((k // W_K_DIVISOR) * W_K_MULTIPLIER)
         if is_microscaled_format and SWIZZLE_MX_SCALE is None:
-            mask_k_scale = offs_k_scale < tl.cdiv(k, MX_PACK_DIVISOR)
+            mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k

     x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
     w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
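
For context on the changed bound, here is a minimal plain-Python sketch of how ceiling division (tl.cdiv) and floor division disagree when k is not a multiple of the divisor. The concrete values are hypothetical and only illustrate the rounding behaviour, not the kernel's actual layout constants:

    # Plain-Python illustration of the rounding change; all values are hypothetical.
    def cdiv(a, b):
        # ceiling division, matching tl.cdiv
        return (a + b - 1) // b

    k = 100             # a K that is not a multiple of the divisor
    W_K_DIVISOR = 8     # hypothetical packing divisor
    W_K_MULTIPLIER = 4  # hypothetical multiplier

    old_bound = cdiv(k, W_K_DIVISOR) * W_K_MULTIPLIER  # 13 * 4 = 52 (rounds up)
    new_bound = (k // W_K_DIVISOR) * W_K_MULTIPLIER    # 12 * 4 = 48 (rounds down)
    print(old_bound, new_bound)                        # 52 48

    # The scale mask is likewise rewritten from units of scale blocks
    # (offs_k_scale < cdiv(k, MX_PACK_DIVISOR)) to units of K
    # (offs_k_scale * MX_PACK_DIVISOR < k).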
