Commit f81125b
authored
[mxfp] handle w_scale w/o swizzle correctly (#8652)
In practice, we can't support w_scale with column-wise strided layout,
since we will divide the reduction dim by 32 then it needs to be a
multiple of 16 for TMA. So, we disable TMA (and persistent kernel) for
this case. Added a test case for this.
Before this PR the test case led to
```
E triton.compiler.errors.CompilationError: at 227:26:
E w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
E w_scales = unswizzle_mx_scale_bw(w_scales)
E else:
E w_scales = WMxScale.load([expt_id, off_k_mx, off_n])
E w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
E
E # --- update accumulator ---
E if is_w_microscaled:
E if SWAP_XW:
E acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
E else:
E acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
E ^
E rhs_scale must be a tensor of shape [256, 4]. Got ['4', '256']
```
The way ``make_dense_tma`` was checking if it was called for scale was
also ambiguous. Previously, it assumed for ``StridedLayout`` it's not
scale which is wrong.
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
- [x] I have added tests.
- `/test` for `lit` tests
- `/unittest` for C++ tests
- `/python/test` for end-to-end tests
- [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
- [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
and using the instructions it generates is not minimal.)1 parent e93fc76 commit f81125b
File tree
4 files changed
+29
-19
lines changed- python/triton_kernels
- tests
- triton_kernels
- tensor_details/layout_details
4 files changed
+29
-19
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
233 | 233 | | |
234 | 234 | | |
235 | 235 | | |
| 236 | + | |
236 | 237 | | |
237 | 238 | | |
238 | 239 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
416 | 416 | | |
417 | 417 | | |
418 | 418 | | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
419 | 425 | | |
420 | 426 | | |
421 | 427 | | |
| |||
526 | 532 | | |
527 | 533 | | |
528 | 534 | | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
529 | 539 | | |
530 | 540 | | |
531 | 541 | | |
532 | | - | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
533 | 547 | | |
| 548 | + | |
534 | 549 | | |
535 | 550 | | |
536 | | - | |
| 551 | + | |
537 | 552 | | |
538 | 553 | | |
539 | 554 | | |
| |||
546 | 561 | | |
547 | 562 | | |
548 | 563 | | |
549 | | - | |
550 | | - | |
551 | | - | |
552 | | - | |
553 | 564 | | |
554 | 565 | | |
555 | 566 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
50 | | - | |
| 50 | + | |
51 | 51 | | |
52 | 52 | | |
53 | 53 | | |
54 | 54 | | |
| 55 | + | |
| 56 | + | |
55 | 57 | | |
56 | 58 | | |
57 | 59 | | |
58 | | - | |
| 60 | + | |
59 | 61 | | |
60 | 62 | | |
61 | | - | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
66 | 66 | | |
67 | 67 | | |
68 | 68 | | |
69 | | - | |
| 69 | + | |
70 | 70 | | |
71 | | - | |
| 71 | + | |
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
| |||
Lines changed: 2 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
37 | | - | |
| 37 | + | |
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
| |||
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
49 | | - | |
50 | | - | |
51 | 49 | | |
52 | | - | |
| 50 | + | |
53 | 51 | | |
54 | 52 | | |
55 | 53 | | |
| |||
0 commit comments