Skip to content

Commit 63807b5

Browse files
quintinwang5whitneywhtsang
authored andcommitted
Revert "Revert "[LAYOUTS] Use least squares solution in invertAndCompose (#5309)""
This reverts commit 7d0818a.
1 parent dac237f commit 63807b5

File tree

7 files changed

+175
-267
lines changed

7 files changed

+175
-267
lines changed

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -687,13 +687,6 @@ class LinearLayout {
687687
// (i.e. every input bit affects the output).
688688
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
689689

690-
// Increase an input dimension without affecting the output dimension. The
691-
// added free variables are mapped to 0, ensuring that the new input
692-
// dimensions correspond directly to the existing output space. The function
693-
// errors out if `newInDimSize` is less than the current size or the new size
694-
// is not a power of 2.
695-
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
696-
697690
std::string toString() const;
698691

699692
friend bool operator==(LinearLayout lhs, LinearLayout rhs);

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -762,42 +762,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
762762
StringAttr kLane = StringAttr::get(ctx, "lane");
763763
StringAttr kWarp = StringAttr::get(ctx, "warp");
764764
StringAttr kBlock = StringAttr::get(ctx, "block");
765-
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
766-
auto numDstRegs = dstLayout->getInDimSize(kRegister);
767-
// The `invertAndCompose` function will generate a layout that is injective
768-
// by assigning new output dimensions to free variables. For instance,
769-
// consider a scenario where `srcLayout` has a free variable in the lane
770-
// dimension, while `dstLayout` has two free variables in the lane
771-
// dimension and also a larger number of registers.
772-
// The injective form of `srcLayout` will add only a single additional row
773-
// to the transformation matrix, whereas the injective form of `dstLayout`
774-
// will add two additional rows. This discrepancy causes misleading results
775-
// because the matrices end up with a different number of rows.
776-
//
777-
// Take `dstLayout ⋅ srcLayout^-1` as an example:
778-
//
779-
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
780-
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
781-
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
782-
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
783-
// 1] → [n + 2, n + 1]
784-
//
785-
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
786-
// variable in registers, and the `(n + 2)`-th row represents the free
787-
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
788-
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
789-
// in two layouts do not correspond to the same free variable.
790-
//
791-
// To address this issue, we pad the free variables in `srcLayout` and
792-
// `dstLayout` to ensure they have the same number of registers. This
793-
// guarantees that the resulting matrices have the same number of rows,
794-
// ensuring consistency in the composition process.
795-
auto numRegs = std::max(numSrcRegs, numDstRegs);
796-
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
797-
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
798-
// comp describes the layout function to create dst from src.
799-
LinearLayout comp =
800-
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
765+
766+
auto comp = dstLayout->invertAndCompose(*srcLayout);
801767
// We try to quotient by the largest subspace first
802768
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
803769
for (auto dim : dims) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
311311
// TODO(Keren): implement warp shuffle instead of using the general
312312
// approach that uses shared memory
313313
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
314-
} else if (llvm::is_contained(dims, kRegister) ||
315-
dstLayout.getInDimSize(kRegister) !=
316-
srcLayout.getInDimSize(kRegister)) {
314+
} else if (llvm::is_contained(dims, kRegister)) {
317315
// Case 4. Transfer between values in the same thread, in which case we
318316
// simply reorder the elements of adaptor.getSrc().
319-
return transferWithinThread(
320-
op, dstLayout.getFreeVariableMasks()[kRegister],
321-
dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
317+
return transferWithinThread(op, *conversion, adaptor, rewriter);
322318
} else {
323319
// Cast 5. The two layouts are equivalent. We should probably remove
324320
// these in RemoveLayoutConversion.
@@ -328,8 +324,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
328324
}
329325

330326
LogicalResult
331-
transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
332-
const LinearLayout &conversion, OpAdaptor adaptor,
327+
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
328+
OpAdaptor adaptor,
333329
ConversionPatternRewriter &rewriter) const {
334330
MLIRContext *ctx = op.getContext();
335331
auto loc = op.getLoc();
@@ -339,16 +335,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
339335
auto srcTy = op.getSrc().getType();
340336
auto dstTy = op.getType();
341337
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
342-
SmallVector<Value> outVals(numRegs);
343-
for (int i = 0; i < numRegs; i++) {
344-
// Remove free masks from the register index
345-
// For example, if idx = 0b00111, and masks = 0b00100, then we get
346-
// 0b00011. It means that register 7 (0b111) has the same value as
347-
// register 3 (0b011).
348-
auto idx = i & (~regMasks);
349-
auto srcIdx = conversion.hasInDim(kRegister)
350-
? conversion.apply({{kRegister, idx}}).begin()->second
351-
: idx;
338+
SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
339+
for (int i = 0; i < outVals.size(); i++) {
340+
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
352341
outVals[i] = inVals[srcIdx];
353342
}
354343
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,

0 commit comments

Comments
 (0)