
Commit 2172bb4

[mxfp] fix bf16 x mxfp4 bug with SUBTILE_FACTOR > 1 (#8478)
```
pytest -rxs --pdb python/triton_kernels/tests/test_matmul.py::test_op
```

Before this PR:

```
E       triton.compiler.errors.CompilationError: at 337:23:
E       if is_out_microscaled:
E           MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
E           N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
E
E       for a_i in tl.static_range(len(accs)):
E           acc_tile = accs[a_i]
E           acc_tile *= x_scale * w_scale
E
E           if SWAP_XW:
E               acc_tile = acc_tile.T
E
E           acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
E                      ^
E       ValueError('Cannot make_shape_compatible: incompatible dimensions at index 0: 64 and 16')
```

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent a4ab31d commit 2172bb4
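For context on the error above: in the epilogue, each accumulator subtile is (optionally) transposed back and must end up (BLOCK_M, OUT_BLOCK_N)-shaped before it is added to `biases[a_i][None, :] * betas[:, None]`. A rough NumPy analogue of the failing broadcast (sizes are illustrative, chosen only to mirror the 64-vs-16 message; this is not kernel code):

```python
import numpy as np

BLOCK_M, OUT_BLOCK_N = 16, 64          # illustrative tile sizes
betas = np.ones(BLOCK_M)
bias = np.zeros(OUT_BLOCK_N)

ok_tile = np.zeros((BLOCK_M, OUT_BLOCK_N))
print((ok_tile + bias[None, :] * betas[:, None]).shape)   # (16, 64): broadcast works

# With SWAP_XW, the old subtile split left the tile with its axes swapped,
# so the same broadcast fails, matching the compiler error above.
bad_tile = np.zeros((OUT_BLOCK_N, BLOCK_M))
# bad_tile + bias[None, :] * betas[:, None]   # ValueError: (64, 16) vs (16, 64) at dim 0
```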

File tree: 2 files changed, +15 -4 lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 2 additions & 1 deletion
@@ -235,6 +235,7 @@ class Case:
         # mx types:
         Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
         Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
+        Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True, epilogue_subtile=4),
         Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
         Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
         Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
@@ -321,7 +322,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     if is_cuda():
         if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
             pytest.skip("Float8 not tested on A100")
-        if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
+        if act_dtype_str == "float16" and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
             pytest.skip("float16 x mx not supported with cuda capability >= 10")
         if weight_dtype_str.startswith("mx"):
             if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 13 additions & 3 deletions
@@ -403,14 +403,24 @@ def _p_matmul_ogs(
         biases = (bias,)

         if SUBTILE_FACTOR >= 2:
-            acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
+            if SWAP_XW:
+                acc = acc.reshape(2, BLOCK_N // 2, BLOCK_M).permute(1, 2, 0)
+            else:
+                acc = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1)
+            acc0, acc1 = acc.split()
             accs = (acc0, acc1)
             bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
             biases = (bias0, bias1)

         if SUBTILE_FACTOR >= 4:
-            acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
-            acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+            if SWAP_XW:
+                acc0 = acc0.reshape(2, BLOCK_N // 4, BLOCK_M).permute(1, 2, 0)
+                acc1 = acc1.reshape(2, BLOCK_N // 4, BLOCK_M).permute(1, 2, 0)
+            else:
+                acc0 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1)
+                acc1 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1)
+            acc00, acc01 = acc0.split()
+            acc10, acc11 = acc1.split()
             accs = (acc00, acc01, acc10, acc11)
             bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
             bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
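For intuition on the kernel change: the reshape/permute above simply regroups the accumulator along its N axis before splitting it in half, and under SWAP_XW the tile is stored transposed as (BLOCK_N, BLOCK_M), so N is the leading axis. A minimal NumPy stand-in for the tile bookkeeping (illustrative shapes; NumPy's `split` keeps a size-1 trailing axis where Triton's `split` drops it):

```python
import numpy as np

BLOCK_M, BLOCK_N = 16, 128   # illustrative tile sizes

# Non-swapped layout: tile is (BLOCK_M, BLOCK_N); regroup and split along N.
acc = np.arange(BLOCK_M * BLOCK_N).reshape(BLOCK_M, BLOCK_N)
acc0, acc1 = np.split(acc.reshape(BLOCK_M, 2, BLOCK_N // 2).transpose(0, 2, 1), 2, axis=-1)
assert acc0[..., 0].shape == (BLOCK_M, BLOCK_N // 2)

# SWAP_XW layout: tile is stored transposed as (BLOCK_N, BLOCK_M), so N must be
# regrouped on the leading axis before the split, as the fix does.
acc_t = acc.T
acc0_t, acc1_t = np.split(acc_t.reshape(2, BLOCK_N // 2, BLOCK_M).transpose(1, 2, 0), 2, axis=-1)
assert acc0_t[..., 0].shape == (BLOCK_N // 2, BLOCK_M)

# Both paths select the same half of the tile, just transposed.
assert np.array_equal(acc0_t[..., 0].T, acc0[..., 0])
```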
