Commit a54f309
Fix out-of-bounds load in mxfp_matmul test kernel. (#7193)
In the current `mxfp_matmul`, a mask is applied to the input data loads over the K dimension to handle blocks bigger than the input tensors. However, no mask is applied to the scale loads, which yields random values for the scales, including NaN values, and can therefore produce NaN values in the output. This is fixed by applying a mask to the scale loads.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because it fixes an existing one.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

Signed-off-by: Ilya Enkovich <[email protected]>
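For context, here is a minimal standalone sketch (not the repository kernel; pointer, stride, and parameter names are hypothetical) of the masking pattern the description refers to: the fp8 data loads are masked along K, and the per-32-element e8m0 scale loads get an analogous mask so that a partial K block never reads past the end of the scale tensors. Note that the merged diff below resolves the problem differently, by dropping the K masking and skipping test shapes where K is not a multiple of BLOCK_K.

```python
import triton
import triton.language as tl


@triton.jit
def mxfp_matmul_sketch(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, out_ptr,
                       K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_om, stride_on,
                       stride_sam, stride_sak, stride_sbn, stride_sbk,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Single output tile; grid / program-id handling is omitted for brevity.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    offs_sk = tl.arange(0, BLOCK_K // 32)  # one e8m0 scale per 32 elements along K

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    a_scale_ptrs = a_scale_ptr + offs_m[:, None] * stride_sam + offs_sk[None, :] * stride_sak
    b_scale_ptrs = b_scale_ptr + offs_n[:, None] * stride_sbn + offs_sk[None, :] * stride_sbk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
        k_remaining = K - k * BLOCK_K
        valid_k = offs_k < k_remaining                 # mask for the fp8 data
        valid_sk = offs_sk < tl.cdiv(k_remaining, 32)  # matching mask for the scales

        a = tl.load(a_ptrs, mask=valid_k[None, :], other=0.)
        b = tl.load(b_ptrs, mask=valid_k[:, None], other=0.)
        # Masking the scale loads avoids the out-of-bounds read described above:
        # unmasked loads past the end of the scale tensors can pick up arbitrary
        # bytes, including the e8m0 NaN encoding, which poisons the accumulator.
        scale_a = tl.load(a_scale_ptrs, mask=valid_sk[None, :], other=0)
        scale_b = tl.load(b_scale_ptrs, mask=valid_sk[None, :], other=0)

        acc = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e5m2", acc)

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        a_scale_ptrs += (BLOCK_K // 32) * stride_sak
        b_scale_ptrs += (BLOCK_K // 32) * stride_sbk

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, acc)
```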
1 parent 51722e6 commit a54f309

File tree

1 file changed: +5 -8 lines changed


python/test/unit/language/test_matmul.py

Lines changed: 5 additions & 8 deletions
@@ -313,10 +313,8 @@ def mxfp_matmul( #
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
     for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
-        k_remaining = K - k * BLOCK_K
-        valid_k = offs_k < k_remaining
-        a = tl.load(a_ptrs, mask=valid_k[None, :], other=0.)
-        b = tl.load(b_ptrs, mask=valid_k[:, None], other=0.)
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs)
         scale_a = tl.load(a_scale_ptr)
         scale_b = tl.load(b_scale_ptr)
         accumulator = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e5m2", accumulator)
@@ -339,21 +337,20 @@ def fp8e8m0_to_float32(scale):
     return scale


-@pytest.mark.parametrize("M, N, K", [(1024, 512, 256), (128, 256, 256), (128, 128, 128), (2, 4, 32), (2, 4, 64),
-                                     (256, 16, 32)])
+@pytest.mark.parametrize("M, N, K", [(1024, 512, 256), (128, 256, 256), (128, 128, 128), (2, 4, 64)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                        (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 3])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0]))
 def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device):
+    if K % BLOCK_K != 0:
+        pytest.skip("Kernel requires shapes aligned by K dimension")
     if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
         pytest.skip("Requires compute capability >= 10")
     elif is_hip():
         if not is_hip_cdna4():
             pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4")
-        if (M == 2 and N == 4 and K == 32) or (M == 256 and N == 16 and K == 32):
-            pytest.skip(f"Input shape {M=}, {N=}, {K=} is not supported yet")
         if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
             pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
