Skip to content

Commit 5d84a91

Browse files
[mxfp] fix x_scale OOB (#8369)
``` pytest python/triton_kernels/tests/test_matmul.py::"test_op[True-True-True-True-False-None-16-300-400-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-4-1-True-None-False-False-False]" ``` <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 59aeb6b commit 5d84a91

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ def _p_matmul_ogs(
311311
mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
312312
else:
313313
mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR)
314-
x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :], other=0.0)
314+
mask_m = off_m + tl.arange(0, BLOCK_M) < eM
315+
x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0)
315316
elif x_format == "fp16" or x_format == "bf16":
316317
x_scales: tl.constexpr = None
317318
else:

0 commit comments

Comments
 (0)