Skip to content

Commit 48468af

Browse files
authored
[LAYOUTS] Enable Slice(Dot) LL conversion (#5400)
There's no reason to disable this one.
1 parent 137bc62 commit 48468af

File tree

3 files changed

+58
-14
lines changed

3 files changed

+58
-14
lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -770,10 +770,7 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
770770
std::optional<LinearLayout> parentLL =
771771
triton::gpu::toLinearLayout(parentShape, getParent());
772772
if (!parentLL.has_value()) {
773-
if (mlir::isa<DotOperandEncodingAttr>(getParent()))
774-
return std::nullopt;
775-
llvm::report_fatal_error(
776-
"Failed to compute parent layout for slice layout.");
773+
return std::nullopt;
777774
}
778775

779776
// Remove dimension getDim() from the parent layout.

python/test/unit/language/test_core.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ def __str__(self):
162162
return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>"
163163

164164

165+
# Test-side description of a slice layout: pairs a `dim` value with a `parent`
# layout object and renders both as the textual
# `#<GPU_DIALECT>.slice<{dim = ..., parent = ...}>` attribute form.
class SliceLayout:
166+
167+
# dim: value printed in the `dim = ...` field (presumably the index of the
#      parent dimension that is sliced away -- confirm against the dialect).
# parent: layout object being sliced; it is formatted via its own __str__.
def __init__(self, dim, parent):
168+
self.dim = dim
169+
self.parent = parent
170+
171+
# Returns the attribute text. The doubled braces in the f-string emit
# literal `{` / `}` characters around the dim/parent fields.
def __str__(self):
172+
return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>"
173+
174+
165175
class BlockedLayout:
166176

167177
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order):
@@ -199,6 +209,8 @@ def is_layout_applicable(layout) -> bool:
199209
common_layouts = [BlockedLayout, SharedLayout]
200210
if layout in common_layouts:
201211
return True
212+
elif isinstance(layout, SliceLayout):
213+
return is_layout_applicable(layout.parent)
202214
elif is_cuda():
203215
mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
204216
if not isinstance(mma_layout, MmaLayout):
@@ -2850,8 +2862,11 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
28502862
# TODO (lixun): Add MfmaLayout
28512863
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
28522864
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
2865+
MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
28532866
MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1],
2854-
instr_shape=[16, 8])
2867+
instr_shape=[16, 8]),
2868+
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
2869+
DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
28552870
]
28562871

28572872

@@ -5281,17 +5296,12 @@ def kernel(Out):
52815296
# TODO: backend should be tested separately
52825297

52835298
layouts = [
5299+
BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
5300+
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
52845301
MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
52855302
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
52865303
DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
5287-
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
5288-
BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
5289-
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
5290-
BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
5291-
BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
5292-
BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
5293-
BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
5294-
BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
5304+
MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]),
52955305
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
52965306
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
52975307
DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
@@ -5300,7 +5310,13 @@ def kernel(Out):
53005310
DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
53015311
DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8),
53025312
DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
5303-
MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]),
5313+
SliceLayout(
5314+
dim=1,
5315+
parent=DotOperandLayout(parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]),
5316+
op_idx=0, k_width=2)),
5317+
SliceLayout(
5318+
dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]),
5319+
op_idx=1, k_width=2)),
53045320
]
53055321

53065322
intermediate_layouts = [

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,37 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
733733
{S("dim0"), S("dim1")}));
734734
}
735735

736+
// Checks that toLinearLayout yields the expected layout for a slice encoding
// whose parent is a dot-operand encoding, once for an MMAv2 parent and once
// for an MMAv3 parent. Both cases request shape {16} and expect a layout
// with the single output dimension "dim0".
TEST_F(LinearLayoutConversionsTest, SliceDot) {
737+
// Slice layout with a DotOperand (MMAv2) as the parent.
738+
// NOTE(review): mma(...) arguments look like (versionMajor, versionMinor,
// instrShape, warpsPerCTA) -- confirm against the fixture helper.
auto parentV2 = dot(mma(2, 0, {16, 8}, {1, 1}), /*opIdx=*/0, /*kWidth=*/8);
739+
auto sliceV2 = slice(parentV2, /*dim=*/1);
740+
741+
// Expected bases per input dimension. The {0} entries under "lane" are
// presumably input bits that do not change the dim0 index (broadcast) --
// confirm against the LinearLayout documentation.
EXPECT_EQ(toLinearLayout({16}, sliceV2),
742+
LinearLayout(
743+
{
744+
{S("register"), {{8}}},
745+
{S("lane"), {{0}, {0}, {1}, {2}, {4}}},
746+
{S("warp"), {}},
747+
{S("block"), {}},
748+
},
749+
{S("dim0")}));
750+
751+
// Slice layout with a DotOperand (MMAv3) as the parent.
752+
auto parentV3 =
753+
dot(mma(3, 0, {16, 16, 8}, {4, 1}), /*opIdx=*/0, /*kWidth=*/2);
754+
// Unlike the MMAv2 case above, this slice is taken along dim 0.
auto sliceV3 = slice(parentV3, /*dim=*/0);
755+
756+
// Here both "warp" bases and the last three "lane" bases are {0}.
EXPECT_EQ(toLinearLayout({16}, sliceV3),
757+
LinearLayout(
758+
{
759+
{S("register"), {{1}, {8}}},
760+
{S("lane"), {{2}, {4}, {0}, {0}, {0}}},
761+
{S("warp"), {{0}, {0}}},
762+
{S("block"), {}},
763+
},
764+
{S("dim0")}));
765+
}
766+
736767
TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
737768
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
738769
/*isTransposed=*/false);

0 commit comments

Comments
 (0)