Commit c7607d2

Revert "[BACKEND] Improve detection of register to register conversion (#4991)"

This reverts commit 15c5e55.
1 parent: 30f0d5d

7 files changed: +24 −202 lines

include/triton/Analysis/Utility.h
Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
-bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
 
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
include/triton/Tools/LinearLayout.h
Lines changed: 0 additions & 7 deletions

@@ -679,13 +679,6 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
 
-  // Increase an input dimension without affecting the output dimension. The
-  // added free variables are mapped to 0, ensuring that the new input
-  // dimensions correspond directly to the existing output space. The function
-  // errors out if `newInDimSize` is less than the current size or the new size
-  // is not a power of 2.
-  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
-
   std::string toString() const;
 
   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
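
Note on the deleted `resize` API: the removed doc comment describes growing an input dimension by appending free variables that map to 0, so the enlarged input space still lands in the existing output space. Below is a minimal standalone sketch of that behavior, using a plain map of basis vectors as a stand-in (the `Bases` type and `resizeInDim` name are illustrative, not Triton's `LinearLayout` API).

#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Each input dimension maps to a list of basis vectors over the output
// dimensions; the dimension's size is 2^(number of bases).
using Bases = std::map<std::string, std::vector<std::vector<int32_t>>>;

// Grow `inDim` to `newInDimSize` by appending all-zero bases, i.e. free
// variables that enlarge the input space without touching any output bit.
Bases resizeInDim(Bases bases, const std::string &inDim, int32_t newInDimSize,
                  std::size_t numOutDims) {
  assert(bases.count(inDim) && "inDim not in layout");
  assert(newInDimSize > 0 && (newInDimSize & (newInDimSize - 1)) == 0 &&
         "newInDimSize must be a power of 2");
  auto &vecs = bases[inDim];
  int32_t log2New = 0;
  while ((1 << log2New) < newInDimSize)
    ++log2New;
  assert(log2New >= static_cast<int32_t>(vecs.size()) &&
         "newInDimSize must be >= old size");
  while (static_cast<int32_t>(vecs.size()) < log2New)
    vecs.push_back(std::vector<int32_t>(numOutDims, 0));
  return bases;
}

In this model, growing a size-4 input dimension to size 8 appends one all-zero basis vector, which is exactly what the `Resize` unit test removed at the bottom of this commit checks.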

lib/Analysis/Utility.cpp
Lines changed: 2 additions & 40 deletions

@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
-bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)

@@ -662,46 +662,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
   if (!(srcLayout.has_value() && dstLayout.has_value()))
     return std::nullopt;
-  StringAttr kRegister = StringAttr::get(ctx, "register");
-  StringAttr kLane = StringAttr::get(ctx, "lane");
-  StringAttr kWarp = StringAttr::get(ctx, "warp");
-  StringAttr kBlock = StringAttr::get(ctx, "block");
-  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
-  auto numDstRegs = dstLayout->getInDimSize(kRegister);
-  // The `invertAndCompose` function will generate a layout that is injective
-  // by assigning new output dimensions to free variables. For instance,
-  // consider a scenario where `srcLayout` has a free variable in the lane
-  // dimension, while `dstLayout` has two free variables in the lane
-  // dimension and also a larger number of registers.
-  // The injective form of `srcLayout` will add only a single additional row
-  // to the transformation matrix, whereas the injective form of `dstLayout`
-  // will add two additional rows. This discrepancy causes misleading results
-  // because the matrices end up with a different number of rows.
-  //
-  // Take `dstLayout ⋅ srcLayout^-1` as an example:
-  //
-  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
-  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
-  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
-  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
-  //     1] → [n + 2, n + 1]
-  //
-  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
-  // variable in registers, and the `(n + 2)`-th row represents the free
-  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
-  // represents the free variable in lanes. As a result, the `(n + 1)`-th row
-  // in two layouts do not correspond to the same free variable.
-  //
-  // To address this issue, we pad the free variables in `srcLayout` and
-  // `dstLayout` to ensure they have the same number of registers. This
-  // guarantees that the resulting matrices have the same number of rows,
-  // ensuring consistency in the composition process.
-  auto numRegs = std::max(numSrcRegs, numDstRegs);
-  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
-  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
   // comp describes the layout function to create dst from src.
-  LinearLayout comp =
-      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
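
The comment block removed above argues that `invertAndCompose` adds one row to each layout's injective form per free variable, so `srcLayout` and `dstLayout` must first be padded to the same register count for those rows to correspond. Condensed into a single sketch for reference (the wrapper name `composeWithPaddedRegisters` is mine; `resize`, `getInDimSize`, and `invertAndCompose` are the `LinearLayout` methods visible in this diff, so this is a sketch against the pre-revert API rather than standalone code):

// Sketch only: pad both layouts to the same number of registers before
// composing, so the rows added for free variables line up one-to-one.
LinearLayout composeWithPaddedRegisters(MLIRContext *ctx,
                                        const LinearLayout &srcLayout,
                                        const LinearLayout &dstLayout) {
  StringAttr kRegister = StringAttr::get(ctx, "register");
  int32_t numRegs = std::max(srcLayout.getInDimSize(kRegister),
                             dstLayout.getInDimSize(kRegister));
  LinearLayout src = srcLayout.resize(kRegister, numRegs); // append zero bases
  LinearLayout dst = dstLayout.resize(kRegister, numRegs); // append zero bases
  // The composition maps source register indices to destination indices.
  return dst.invertAndCompose(src);
}

The revert drops this padding and composes `dstLayout` with `*srcLayout` directly, as shown in the `+` line above.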

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Lines changed: 21 additions & 32 deletions

@@ -288,71 +288,60 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return rewriter.notifyMatchFailure(
           op, "NYI. srcTy and/or dstTy don't implement LLs yet");
     }
-    LinearLayout srcLayout =
-        *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    LinearLayout dstLayout =
-        *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-
-    StringAttr kBlock = str_attr("block");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kRegister = str_attr("register");
 
     assert(to_vector(conversion->getInDimNames()) ==
            to_vector(conversion->getOutDimNames()));
     auto dims = conversion->getInDimNames();
-    if (llvm::is_contained(dims, kBlock)) {
+    if (llvm::is_contained(dims, str_attr("block"))) {
       // Case 1: Transfer between values in different CTAs.
       // This requires moving values through distributed shared memory.
       return rewriter.notifyMatchFailure(
           op, "NYI: Transfer between different CTAs");
-    } else if (llvm::is_contained(dims, kWarp)) {
+    } else if (llvm::is_contained(dims, str_attr("warp"))) {
       // Case 2: Transfer between values in the same CTA, in which case we move
       // values through shared memory.
+      LinearLayout srcLayout =
+          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+      LinearLayout dstLayout =
+          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, kLane)) {
+    } else if (llvm::is_contained(dims, str_attr("lane"))) {
       // Case 3. Transfer between values in the same warp, in which case we try
       // to move values using warp shuffles, though if the pattern is
       // complicated enough we may fall back to using shared memory
      // TODO(Keren): implement warp shuffle instead of using the general
      // approach that uses shared memory
+      LinearLayout srcLayout =
+          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+      LinearLayout dstLayout =
+          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, kRegister) ||
-               dstLayout.getInDimSize(kRegister) !=
-                   srcLayout.getInDimSize(kRegister)) {
+    } else if (llvm::is_contained(dims, str_attr("register"))) {
       // Case 4. Transfer between values in the same thread, in which case we
      // simply reorder the elements of adaptor.getSrc().
-      return transferWithinThread(
-          op, dstLayout.getFreeVariableMasks()[kRegister],
-          dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
+      return transferWithinThread(op, *conversion, adaptor, rewriter);
     } else {
-      // Cast 5. The two layouts are equivalent. We should probably remove
-      // these in RemoveLayoutConversion.
+      // The two layouts are equivalent. We should probably remove these in
+      // RemoveLayoutConversion.
       rewriter.replaceOp(op, adaptor.getSrc());
       return success();
     }
   }
 
   LogicalResult
-  transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
-                       const LinearLayout &conversion, OpAdaptor adaptor,
+  transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
+                       OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    SmallVector<Value> outVals(numRegs);
-    for (int i = 0; i < outVals.size(); i++) {
-      // Remove free masks from the register index
-      // For example, if idx = 0b00111, and masks = 0b00100, then we get
-      // 0b00011. It means that register 7 (0b111) has the same value as
-      // register 3 (0b011).
-      auto idx = i & (~regMasks);
-      auto srcIdx = conversion.hasInDim(kRegister)
-                        ? conversion.apply({{kRegister, idx}}).begin()->second
-                        : idx;
+    SmallVector<Value> outVals;
+    outVals.resize(conversion.getInDimSize(kRegister));
+    for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
+      auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
       outVals[i] = inVals[srcIdx];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
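
Two of the explanations touched in this file are worth a closer look. First, the restored `matchAndRewrite` dispatch picks a transfer strategy from the largest hardware dimension the conversion layout still involves. A self-contained sketch of that decision order (the `TransferKind` enum and `classifyConversion` name are illustrative, not part of Triton):

#include <set>
#include <string>

enum class TransferKind {
  CrossCTA,     // Case 1: needs distributed shared memory (NYI)
  WithinCTA,    // Case 2: round trip through shared memory
  WithinWarp,   // Case 3: warp shuffles, currently also via shared memory
  WithinThread, // Case 4: pure register permutation
  NoOp          // the two layouts are already equivalent
};

// Mirrors the post-revert if/else chain above: the first (widest) dimension
// contained in the conversion decides how values have to move.
TransferKind classifyConversion(const std::set<std::string> &dims) {
  if (dims.count("block"))
    return TransferKind::CrossCTA;
  if (dims.count("warp"))
    return TransferKind::WithinCTA;
  if (dims.count("lane"))
    return TransferKind::WithinWarp;
  if (dims.count("register"))
    return TransferKind::WithinThread;
  return TransferKind::NoOp;
}

Second, the deleted loop in `transferWithinThread` cleared the free-variable bits of each register index before applying the conversion, so registers that differ only in free bits read the same source element. A standalone illustration using the values from the deleted comment:

#include <cstdio>

int main() {
  const unsigned regMasks = 0b00100u; // bit 2 is a free variable
  for (unsigned i = 0; i < 8; ++i) {
    // Clearing free bits: 0b00111 & ~0b00100 == 0b00011, so register 7
    // reads the same source element as register 3.
    unsigned idx = i & ~regMasks;
    std::printf("dst register %u <- representative source index %u\n", i, idx);
  }
  return 0;
}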

lib/Tools/LinearLayout.cpp
Lines changed: 0 additions & 15 deletions

@@ -1016,21 +1016,6 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
   return true;
 }
 
-LinearLayout LinearLayout::resize(StringAttr inDim,
-                                  int32_t newInDimSize) const {
-  BasesT bases = getBases();
-  assert(bases.contains(inDim) && "inDim not in layout");
-  assert(llvm::isPowerOf2_32(newInDimSize) &&
-         "newInDimSize must be a power of 2");
-  assert(newInDimSize >= getInDimSize(inDim) &&
-         "newInDimSize must be >= old size");
-  auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim);
-  for (int i = 0; i < numFreeVariables; i++) {
-    bases[inDim].push_back(std::vector<int32_t>(getNumOutDims(), 0));
-  }
-  return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames()));
-}
-
 std::string LinearLayout::toString() const {
   // Start with a newline because we print out a bulleted list; it doesn't
   // make sense for the first line of this list to be on the same line as

test/Conversion/tritongpu_to_llvm.mlir
Lines changed: 0 additions & 74 deletions

@@ -847,80 +847,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
-#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
-#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
-  // CHECK-LABEL: convert_layout_mmav2_dot_reg
-  tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
-    // CHECK-NOT: st.shared
-    // CHECK-NOT: llvm.load
-    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
-    tt.return
-  }
-}
-
-// -----
-
-#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
-#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: convert_layout_mmav3_mmav3_0
-  tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
-    // CHECK-NOT: st.shared
-    // CHECK-NOT: llvm.load
-    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
-    tt.return
-  }
-}
-
-// -----
-
-#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
-#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: convert_layout_mmav3_mmav3_1
-  tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
-    // CHECK-NOT: st.shared
-    // CHECK-NOT: llvm.load
-    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
-    tt.return
-  }
-}
-
-// -----
-
-#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
-#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: convert_layout_mmav3_mmav3_2
-  tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
-    // CHECK-NOT: st.shared
-    // CHECK-NOT: llvm.load
-    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
-    tt.return
-  }
-}
-
-// -----
-
-#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
-#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: convert_layout_mmav3_mmav3_3
-  tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
-    // CHECK-NOT: st.shared
-    // CHECK-NOT: llvm.load
-    %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
-    tt.return
-  }
-}
-
-// -----
-
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {

unittest/Tools/LinearLayoutTest.cpp
Lines changed: 0 additions & 33 deletions

@@ -747,39 +747,6 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) {
   ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value());
 }
 
-TEST_F(LinearLayoutTest, Resize) {
-  auto init = LinearLayout(
-      {
-          {S("in0"), {{0, 1}, {0, 2}}},
-          {S("in1"), {{1, 0}, {2, 0}}},
-          {S("in2"), {}},
-      },
-      {S("dim0"), S("dim1")});
-  EXPECT_EQ(init.resize(S("in0"), 8),
-            LinearLayout(
-                {
-                    {S("in0"), {{0, 1}, {0, 2}, {0, 0}}},
-                    {S("in1"), {{1, 0}, {2, 0}}},
-                    {S("in2"), {}},
-                },
-                {S("dim0"), S("dim1")}));
-  EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout(
-                                          {
-                                              {S("in0"), {{0, 1}, {0, 2}}},
-                                              {S("in1"), {{1, 0}, {2, 0}}},
-                                              {S("in2"), {}},
-                                          },
-                                          {S("dim0"), S("dim1")}));
-  EXPECT_EQ(init.resize(S("in1"), 8),
-            LinearLayout(
-                {
-                    {S("in0"), {{0, 1}, {0, 2}}},
-                    {S("in1"), {{1, 0}, {2, 0}, {0, 0}}},
-                    {S("in2"), {}},
-                },
-                {S("dim0"), S("dim1")}));
-}
-
 } // anonymous namespace
 } // namespace mlir::triton
 