Skip to content

Commit 7e401df

Browse files
jataylojungpark-mlirpeterbell10htyuantiagainst
authored
[AMD] rc/3.2.x cherry picks (triton-lang#5347)
Reverts triton-lang#5191 due to some mlir errors in pytorch unit tests Smaller set of cherry picks: - triton-lang#5308 (and previous LLVM upgrades) - triton-lang#5281 - triton-lang#4925 - triton-lang#5053 - triton-lang#5019 - triton-lang#4998 --------- Co-authored-by: Jungwook Park <[email protected]> Co-authored-by: peterbell10 <[email protected]> Co-authored-by: Hongtao Yu <[email protected]> Co-authored-by: Lei Zhang <[email protected]> Co-authored-by: Ilya V <[email protected]> Co-authored-by: Kyle Wang <[email protected]>
1 parent 2d8093c commit 7e401df

File tree

26 files changed

+739
-882
lines changed

26 files changed

+739
-882
lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,11 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
212212

213213
bool atomicNeedsSharedMemory(Value result);
214214

215-
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
215+
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
216+
217+
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
218+
219+
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
216220

217221
// Return true if the src and dst layout match.
218222
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ namespace gpu {
1818
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
1919
Type ouType);
2020

21+
SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
22+
ConversionPatternRewriter &rewriter, Location loc,
23+
const LLVMTypeConverter *typeConverter);
24+
25+
SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
26+
ConversionPatternRewriter &rewriter, Location loc,
27+
const LLVMTypeConverter *typeConverter);
28+
2129
Type getElementType(Value value);
2230

2331
class MultipleOperandsRange
@@ -179,8 +187,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
179187
for (auto operand : adaptor.getOperands()) {
180188
auto argTy = op->getOperand(0).getType();
181189
auto subOperands = unpackLLElements(loc, operand, rewriter);
182-
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
183-
this->getTypeConverter());
190+
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
191+
this->getTypeConverter());
184192
allOperands.resize(subOperands.size());
185193
for (auto v : llvm::enumerate(subOperands))
186194
allOperands[v.index()].push_back(v.value());
@@ -207,7 +215,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
207215
}
208216
resultVals = maybeDeduplicate(op, resultVals);
209217
resultVals =
210-
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
218+
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
211219
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
212220
rewriter, resultTy);
213221
rewriter.replaceOp(op, view);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,67 +1388,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
13881388
return llvmStruct;
13891389
}
13901390

1391-
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
1392-
// instructions to pack & unpack sub-word integers. A workaround is to
1393-
// store the results of tensors with dot operand encodings in i32 to
1394-
// facilitate instructions such as `ldmatrix`.
1395-
//
1396-
// TODO: Confirm if the problem is still there.
1397-
inline bool requiresI32Conversion(Type type) {
1398-
auto tensorTy = dyn_cast<RankedTensorType>(type);
1399-
if (!tensorTy)
1400-
return false;
1401-
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
1402-
if (!dotOpEnc)
1403-
return false;
1404-
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
1405-
if (!(parent && parent.getVersionMajor() < 3))
1406-
return false;
1407-
return true;
1408-
}
1409-
1410-
inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
1411-
Type type, RewriterBase &rewriter,
1412-
Location loc,
1413-
const LLVMTypeConverter *typeConverter) {
1414-
if (!requiresI32Conversion(type))
1415-
return inValues;
1416-
Type eltTy =
1417-
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
1418-
1419-
SmallVector<Value> outValues;
1420-
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
1421-
auto vecTy = vec_ty(eltTy, vecWidth);
1422-
for (int i = 0; i < inValues.size(); i += vecWidth) {
1423-
Value vec = undef(vecTy);
1424-
for (int j = 0; j < vecWidth; j++) {
1425-
vec = insert_element(vec, inValues[i + j], i32_val(j));
1426-
}
1427-
outValues.push_back(bitcast(vec, i32_ty));
1428-
}
1429-
return outValues;
1430-
}
1431-
1432-
inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
1433-
Type type, RewriterBase &rewriter,
1434-
Location loc,
1435-
const LLVMTypeConverter *typeConverter) {
1436-
if (!requiresI32Conversion(type))
1437-
return inValues;
1438-
Type eltTy =
1439-
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
1440-
1441-
SmallVector<Value> outValues;
1442-
for (auto v : inValues) {
1443-
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
1444-
auto vec = bitcast(v, vecTy);
1445-
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
1446-
outValues.push_back(extract_element(vec, i32_val(i)));
1447-
}
1448-
}
1449-
return outValues;
1450-
}
1451-
14521391
inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
14531392
RewriterBase &rewriter) {
14541393
assert(bool(llvmStruct) && "can not unpack null values");

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -679,13 +679,6 @@ class LinearLayout {
679679
// (i.e. every input bit affects the output).
680680
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
681681

682-
// Increase an input dimension without affecting the output dimension. The
683-
// added free variables are mapped to 0, ensuring that the new input
684-
// dimensions correspond directly to the existing output space. The function
685-
// errors out if `newInDimSize` is less than the current size or the new size
686-
// is not a power of 2.
687-
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
688-
689682
std::string toString() const;
690683

691684
friend bool operator==(LinearLayout lhs, LinearLayout rhs);

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
113113
Attribute srcLayout = srcTy.getEncoding();
114114
Attribute dstLayout = dstTy.getEncoding();
115115

116-
assert(cvtNeedsSharedMemory(srcTy, dstTy));
116+
assert(!isMfmaToDotShortcut(srcTy, dstTy));
117117

118118
// FIXME This is NOT entirely correct
119119
// This should be getElemOrder, but we don't have such a method

lib/Analysis/Utility.cpp

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
536536
(elemTy.isInteger(8) && version >= 2);
537537
}
538538

539-
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
539+
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
540540
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
541541
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
542542
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
605605
return matrixDimsCompatible && bDimCompatible;
606606
}
607607

608+
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
609+
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
610+
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
611+
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
612+
return false;
613+
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
614+
// improved. In addition, we can enable this shortcut for regular MFMA
615+
// layout when opIdx == 1.
616+
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
617+
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
618+
dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
619+
dotOperandLayout.getParent() == mfmaLayout &&
620+
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
621+
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
622+
}
623+
608624
// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
609625
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
610626
RankedTensorType dstTy) {
@@ -639,46 +655,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
639655
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
640656
if (!(srcLayout.has_value() && dstLayout.has_value()))
641657
return std::nullopt;
642-
StringAttr kRegister = StringAttr::get(ctx, "register");
643-
StringAttr kLane = StringAttr::get(ctx, "lane");
644-
StringAttr kWarp = StringAttr::get(ctx, "warp");
645-
StringAttr kBlock = StringAttr::get(ctx, "block");
646-
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
647-
auto numDstRegs = dstLayout->getInDimSize(kRegister);
648-
// The `invertAndCompose` function will generate a layout that is injective
649-
// by assigning new output dimensions to free variables. For instance,
650-
// consider a scenario where `srcLayout` has a free variable in the lane
651-
// dimension, while `dstLayout` has two free variables in the lane
652-
// dimension and also a larger number of registers.
653-
// The injective form of `srcLayout` will add only a single additional row
654-
// to the transformation matrix, whereas the injective form of `dstLayout`
655-
// will add two additional rows. This discrepancy causes misleading results
656-
// because the matrices end up with a different number of rows.
657-
//
658-
// Take `dstLayout ⋅ srcLayout^-1` as an example:
659-
//
660-
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
661-
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
662-
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
663-
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
664-
// 1] → [n + 2, n + 1]
665-
//
666-
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
667-
// variable in registers, and the `(n + 2)`-th row represents the free
668-
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
669-
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
670-
// in two layouts do not correspond to the same free variable.
671-
//
672-
// To address this issue, we pad the free variables in `srcLayout` and
673-
// `dstLayout` to ensure they have the same number of registers. This
674-
// guarantees that the resulting matrices have the same number of rows,
675-
// ensuring consistency in the composition process.
676-
auto numRegs = std::max(numSrcRegs, numDstRegs);
677-
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
678-
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
679658
// comp describes the layout function to create dst from src.
680-
LinearLayout comp =
681-
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
659+
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
682660
// We try to quotient by the largest subspace first
683661
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
684662
for (auto dim : dims) {
@@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
715693
}
716694

717695
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
718-
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
719-
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
720-
// checks.
696+
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
697+
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
698+
// subsumed by the linear-layout checks.
721699
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
722700
// supported yet in Triton's backend.
723701
return !cvtReordersRegisters(srcTy, dstTy) &&
724702
!isBlockedToDotShortcut(srcTy, dstTy) &&
725-
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
703+
!isMmaToDotShortcut(srcTy, dstTy) &&
704+
!isMfmaToDotShortcut(srcTy, dstTy);
726705
}
727706

728707
bool atomicNeedsSharedMemory(Value value) {
@@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) {
732711
return true;
733712
}
734713

714+
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
715+
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
716+
return true;
717+
// dot_op<opIdx=0, parent=#mma> = #mma
718+
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
719+
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
720+
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
721+
return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
722+
mmaLayout.getWarpsPerCTA()[1] == 1 &&
723+
dotOperandLayout.getOpIdx() == 0 &&
724+
dotOperandLayout.getParent() == mmaLayout &&
725+
!srcTy.getElementType().isF32();
726+
}
727+
735728
namespace {
736729

737730
/// A data structure similar to SetVector but maintains

0 commit comments

Comments
 (0)