
Commit 1c3c3c3

Revert "Revert "[BACKEND] Improve detection of register to register conversion (#4991)""

This reverts commit c7607d2.

1 parent 35dde63 commit 1c3c3c3

7 files changed: +172 -19 lines changed

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,

include/triton/Tools/LinearLayout.h

Lines changed: 7 additions & 0 deletions

@@ -679,6 +679,13 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
 
+  // Increase an input dimension without affecting the output dimension. The
+  // added free variables are mapped to 0, ensuring that the new input
+  // dimensions correspond directly to the existing output space. The function
+  // errors out if `newInDimSize` is less than the current size or the new size
+  // is not a power of 2.
+  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
+
   std::string toString() const;
 
   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
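
For intuition, here is a minimal sketch, assuming a single input and a single output dimension, of the GF(2) arithmetic behind `resize`: a LinearLayout maps an input index to the XOR of the basis vectors selected by its set bits, so an appended zero basis is a free variable whose input bit never affects the output. This is illustrative only, not code from the commit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: evaluate a 1-D linear layout over GF(2) by XORing the basis
// vectors selected by the set bits of the input index.
int32_t apply(const std::vector<int32_t> &bases, int32_t in) {
  int32_t out = 0;
  for (size_t b = 0; b < bases.size(); ++b)
    if (in & (1 << b))
      out ^= bases[b]; // each set input bit contributes one basis vector
  return out;
}

int main() {
  std::vector<int32_t> bases = {1, 2};      // identity layout of size 4
  std::vector<int32_t> resized = {1, 2, 0}; // after resize to 8: zero basis
  // The appended zero basis is a free variable: inputs 4..7 alias 0..3.
  for (int32_t i = 0; i < 4; ++i)
    assert(apply(resized, i + 4) == apply(bases, i));
  return 0;
}
```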

lib/Analysis/Utility.cpp

Lines changed: 40 additions & 2 deletions

@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -646,8 +646,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
   if (!(srcLayout.has_value() && dstLayout.has_value()))
     return std::nullopt;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`:
+  //     [n + 2, m] ⋅ [m, n + 1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th rows
+  // in the two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
   // comp describes the layout function to create dst from src.
-  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
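
To make the padding step concrete, here is a hedged sketch with made-up sizes (4 source registers, 16 destination registers); the variable names match the diff above, and the sizes are hypothetical:

```cpp
// Sketch of the padding above, assuming srcLayout has 4 registers and
// dstLayout has 16 (illustrative sizes, not from the commit).
auto numSrcRegs = srcLayout->getInDimSize(kRegister); // 4
auto numDstRegs = dstLayout->getInDimSize(kRegister); // 16
auto numRegs = std::max(numSrcRegs, numDstRegs);      // 16
// srcLayout gains two zero bases (free register bits); dstLayout is unchanged.
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
// Both injective forms now add rows for the same register bits, so the
// composed matrices have matching row counts.
LinearLayout comp =
    dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
```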

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 16 deletions

@@ -328,20 +328,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else {
       // Cast 5. The two layouts are equivalent. We should probably remove
       // these in RemoveLayoutConversion.
-      auto dstCvt = requiresI32Conversion(dstTy);
-      auto srcCvt = requiresI32Conversion(srcTy);
-      if (dstCvt || srcCvt) {
-        auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
-        inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
-                            getTypeConverter());
-        inVals =
-            packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
-        auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
-                                  rewriter, op.getType());
-        rewriter.replaceOp(op, res);
-      } else {
-        rewriter.replaceOp(op, adaptor.getSrc());
-      }
+      rewriter.replaceOp(op, adaptor.getSrc());
       return success();
     }
   }
@@ -358,9 +345,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
     SmallVector<Value> outVals(numRegs);
-    for (int i = 0; i < numRegs; i++) {
+    for (int i = 0; i < outVals.size(); i++) {
       // Remove free masks from the register index
       // For example, if idx = 0b00111, and masks = 0b00100, then we get
       // 0b00011. It means that register 7 (0b111) has the same value as
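
The loop comment above describes clearing free-variable bits from a register index so that aliased registers share one canonical index. A minimal standalone sketch of that arithmetic (the helper name is hypothetical, not from the commit):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the comment above: zero out the
// free-variable bits of a register index.
int32_t removeFreeBits(int32_t idx, int32_t freeMask) {
  return idx & ~freeMask;
}

int main() {
  // idx = 0b00111 with mask = 0b00100 yields 0b00011: register 7 (0b111)
  // holds the same value as register 3 (0b011).
  assert(removeFreeBits(0b00111, 0b00100) == 0b00011);
  return 0;
}
```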

lib/Tools/LinearLayout.cpp

Lines changed: 15 additions & 0 deletions

@@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
   return true;
 }
 
+LinearLayout LinearLayout::resize(StringAttr inDim,
+                                  int32_t newInDimSize) const {
+  BasesT bases = getBases();
+  assert(bases.contains(inDim) && "inDim not in layout");
+  assert(llvm::isPowerOf2_32(newInDimSize) &&
+         "newInDimSize must be a power of 2");
+  assert(newInDimSize >= getInDimSize(inDim) &&
+         "newInDimSize must be >= old size");
+  auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim);
+  for (int i = 0; i < numFreeVariables; i++) {
+    bases[inDim].push_back(std::vector<int32_t>(getNumOutDims(), 0));
+  }
+  return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames()));
+}
+
 std::string LinearLayout::toString() const {
   // Start with a newline because we print out a bulleted list; it doesn't
   // make sense for the first line of this list to be on the same line as
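
A possible usage sketch, borrowing the `S(...)` StringAttr shorthand from the unit tests below and assuming a single output dimension (the layout values here are invented for illustration):

```cpp
// Sketch only: pad the "register" dimension of a 4-element identity layout
// to 8 entries.
LinearLayout layout({{S("register"), {{1}, {2}}}}, {S("dim0")});
LinearLayout padded = layout.resize(S("register"), 8);
// padded's bases for "register" are now {{1}, {2}, {0}}: the appended zero
// basis makes register indices 4..7 map to the same outputs as 0..3.
```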

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 74 additions & 0 deletions

@@ -946,6 +946,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_layout_mmav2_dot_reg
+  tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_0
+  tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_1
+  tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_2
+  tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_3
+  tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {

unittest/Tools/LinearLayoutTest.cpp

Lines changed: 33 additions & 0 deletions

@@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) {
   ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value());
 }
 
+TEST_F(LinearLayoutTest, Resize) {
+  auto init = LinearLayout(
+      {
+          {S("in0"), {{0, 1}, {0, 2}}},
+          {S("in1"), {{1, 0}, {2, 0}}},
+          {S("in2"), {}},
+      },
+      {S("dim0"), S("dim1")});
+  EXPECT_EQ(init.resize(S("in0"), 8),
+            LinearLayout(
+                {
+                    {S("in0"), {{0, 1}, {0, 2}, {0, 0}}},
+                    {S("in1"), {{1, 0}, {2, 0}}},
+                    {S("in2"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout(
+                                          {
+                                              {S("in0"), {{0, 1}, {0, 2}}},
+                                              {S("in1"), {{1, 0}, {2, 0}}},
+                                              {S("in2"), {}},
+                                          },
+                                          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(init.resize(S("in1"), 8),
+            LinearLayout(
+                {
+                    {S("in0"), {{0, 1}, {0, 2}}},
+                    {S("in1"), {{1, 0}, {2, 0}, {0, 0}}},
+                    {S("in2"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
 } // anonymous namespace
 } // namespace mlir::triton
