Commit 2d8093c

Authored by jataylo, joviliast, antiagainst, zhanglx13, and Jokeren
[AMD] release/3.2.x AMD perf cherry picks (triton-lang#5191)
Cherry pick list:

- triton-lang#4925
- triton-lang#5053
- triton-lang#5019
- triton-lang#5002
- triton-lang#4935 (required additional cherry picks triton-lang#4991 and triton-lang#4951)
- triton-lang#4998
- triton-lang#4925
- triton-lang#5281
- triton-lang#5308
- All previous LLVM hash PRs before triton-lang#5308

---------

Co-authored-by: Ilya V <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Lixun Zhang <[email protected]>
Co-authored-by: Keren Zhou <[email protected]>
Co-authored-by: Alexander Efimov <[email protected]>
Co-authored-by: Kyle Wang <[email protected]>
Co-authored-by: Jungwook Park <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
Co-authored-by: Hongtao Yu <[email protected]>
1 parent 35c6c7c commit 2d8093c

42 files changed (+1556, −971 lines)


cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-b5cc222d7429fe6f18c787f633d5262fac2e676f
+1f20eee6dc367bd202895e3eedb03974a628ef16

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 5 deletions
@@ -212,11 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

 bool atomicNeedsSharedMemory(Value result);

-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
-
-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
-bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 3 additions & 11 deletions
@@ -18,14 +18,6 @@ namespace gpu {
 SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
                                  Type ouType);

-SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
-                             ConversionPatternRewriter &rewriter, Location loc,
-                             const LLVMTypeConverter *typeConverter);
-
-SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           const LLVMTypeConverter *typeConverter);
-
 Type getElementType(Value value);

 class MultipleOperandsRange
@@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      subOperands = unpackI32(subOperands, argTy, rewriter, loc,
-                              this->getTypeConverter());
+      subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
+                               this->getTypeConverter());
       allOperands.resize(subOperands.size());
       for (auto v : llvm::enumerate(subOperands))
         allOperands[v.index()].push_back(v.value());
@@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     }
     resultVals = maybeDeduplicate(op, resultVals);
     resultVals =
-        packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
+        packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
     Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
                                 rewriter, resultTy);
     rewriter.replaceOp(op, view);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 61 additions & 0 deletions
@@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc,
   return llvmStruct;
 }

+// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
+// instructions to pack & unpack sub-word integers. A workaround is to
+// store the results of tensors with dot operand encodings in i32 to
+// facilitate instructions such as `ldmatrix`.
+//
+// TODO: Confirm if the problem is still there.
+inline bool requiresI32Conversion(Type type) {
+  auto tensorTy = dyn_cast<RankedTensorType>(type);
+  if (!tensorTy)
+    return false;
+  auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
+  if (!dotOpEnc)
+    return false;
+  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
+  if (!(parent && parent.getVersionMajor() < 3))
+    return false;
+  return true;
+}
+
+inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
+                                   Type type, RewriterBase &rewriter,
+                                   Location loc,
+                                   const LLVMTypeConverter *typeConverter) {
+  if (!requiresI32Conversion(type))
+    return inValues;
+  Type eltTy =
+      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
+
+  SmallVector<Value> outValues;
+  int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
+  auto vecTy = vec_ty(eltTy, vecWidth);
+  for (int i = 0; i < inValues.size(); i += vecWidth) {
+    Value vec = undef(vecTy);
+    for (int j = 0; j < vecWidth; j++) {
+      vec = insert_element(vec, inValues[i + j], i32_val(j));
+    }
+    outValues.push_back(bitcast(vec, i32_ty));
+  }
+  return outValues;
+}
+
+inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
+                                     Type type, RewriterBase &rewriter,
+                                     Location loc,
+                                     const LLVMTypeConverter *typeConverter) {
+  if (!requiresI32Conversion(type))
+    return inValues;
+  Type eltTy =
+      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
+
+  SmallVector<Value> outValues;
+  for (auto v : inValues) {
+    auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
+    auto vec = bitcast(v, vecTy);
+    for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
+      outValues.push_back(extract_element(vec, i32_val(i)));
+    }
+  }
+  return outValues;
+}
+
 inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                            RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
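Note: the new `packI32s`/`unpackI32s` helpers above (and the renamed call sites in `ElementwiseOpToLLVMBase.h` earlier in this commit) pack sub-word element values into 32-bit words and unpack them again. As a rough standalone illustration of the underlying idea only, the sketch below packs four 8-bit values into one `uint32_t` and recovers them with shifts and masks; it is a conceptual analogy, not the MLIR/LLVM-dialect code from the diff, and the helper names are invented.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack groups of four 8-bit lanes into 32-bit words (little-endian lane order).
std::vector<uint32_t> packBytesToWords(const std::vector<uint8_t> &bytes) {
  std::vector<uint32_t> words;
  for (std::size_t i = 0; i + 3 < bytes.size(); i += 4) {
    uint32_t w = 0;
    for (int lane = 0; lane < 4; ++lane)
      w |= uint32_t(bytes[i + lane]) << (8 * lane);
    words.push_back(w);
  }
  return words;
}

// Inverse: split each 32-bit word back into its four 8-bit lanes.
std::vector<uint8_t> unpackWordsToBytes(const std::vector<uint32_t> &words) {
  std::vector<uint8_t> bytes;
  for (uint32_t w : words)
    for (int lane = 0; lane < 4; ++lane)
      bytes.push_back(uint8_t((w >> (8 * lane)) & 0xFF));
  return bytes;
}
```

In the actual helpers, `packI32s` builds a short vector of the sub-word element type with `insert_element` and bitcasts it to `i32`, and `unpackI32s` performs the reverse; both are no-ops unless `requiresI32Conversion` detects a dot-operand encoding whose parent is an NVIDIA MMA encoding with version major < 3.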

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
@@ -727,6 +727,10 @@ def TT_ReduceOp: TT_Op<"reduce",
     llvm::SmallVector<RankedTensorType> getInputTypes();
     llvm::SmallVector<Type> getElementTypes();
     unsigned getNumOperands();
+
+    // Returns the CombineOp iff this ReduceOp's region contains only
+    // one CombineOp other than the return, or nullptr if not applicable.
+    ::mlir::Operation *getSingleCombiner();
   }];
 }

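For context, the check documented above amounts to "the reduce region's block contains exactly one op besides its terminator". The sketch below is a hypothetical illustration of that shape check using generic MLIR APIs; it is not the actual `getSingleCombiner` implementation from this commit, and the function name is made up.

```cpp
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"

// Hypothetical sketch: return the lone non-terminator op in a single-block
// region, or nullptr if the block holds zero or several such ops.
mlir::Operation *getSingleCombinerSketch(mlir::Region &combineRegion) {
  if (!combineRegion.hasOneBlock())
    return nullptr;
  mlir::Operation *combiner = nullptr;
  for (mlir::Operation &op : combineRegion.front()) {
    if (op.hasTrait<mlir::OpTrait::IsTerminator>())
      continue; // skip the region's return/yield
    if (combiner)
      return nullptr; // more than one candidate combiner
    combiner = &op;
  }
  return combiner;
}
```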

include/triton/Tools/LinearLayout.h

Lines changed: 7 additions & 0 deletions
@@ -679,6 +679,13 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

+  // Increase an input dimension without affecting the output dimension. The
+  // added free variables are mapped to 0, ensuring that the new input
+  // dimensions correspond directly to the existing output space. The function
+  // errors out if `newInDimSize` is less than the current size or the new size
+  // is not a power of 2.
+  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
+
   std::string toString() const;

   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
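The new `resize` documented above pads an input dimension with free variables that map to 0. As a rough sketch of the idea, assuming a 1-D layout is stored as one basis (output value) per input bit, growing the dimension simply appends zero bases; this is illustrative only and does not mirror `LinearLayout`'s real internal representation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One basis entry per input bit: an input dimension of size 2^k has k
// entries, and the layout maps input x to the XOR of bases[i] over the set
// bits i of x.
std::vector<uint32_t> resizeInDim(std::vector<uint32_t> bases,
                                  int32_t newInDimSize) {
  assert(newInDimSize > 0 && (newInDimSize & (newInDimSize - 1)) == 0 &&
         "new size must be a power of 2");
  std::size_t newRank = 0;
  while ((int32_t(1) << newRank) < newInDimSize)
    ++newRank;
  assert(newRank >= bases.size() && "resize cannot shrink an input dimension");
  bases.resize(newRank, 0u); // the added bits are free variables mapped to 0
  return bases;
}
```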

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();

-  assert(!isMfmaToDotShortcut(srcTy, dstTy));
+  assert(cvtNeedsSharedMemory(srcTy, dstTy));

   // FIXME This is NOT entirely correct
   // This should be getElemOrder, but we don't have such a method

lib/Analysis/Utility.cpp

Lines changed: 44 additions & 37 deletions
@@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }

-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   return matrixDimsCompatible && bDimCompatible;
 }

-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
-    return false;
-  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
-  // improved. In addition, we can enable this shortcut for regular MFMA
-  // layout when opIdx == 1.
-  return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
-         dotOperandLayout.getParent() == mfmaLayout &&
-         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
-         (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
-}
-
 // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy) {
@@ -655,8 +639,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
   if (!(srcLayout.has_value() && dstLayout.has_value()))
     return std::nullopt;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
+  //     1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th row
+  // in two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
   // comp describes the layout function to create dst from src.
-  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
@@ -693,15 +715,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }

 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
-  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
-  // subsumed by the linear-layout checks.
+  // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
+  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
+  // checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !isMmaToDotShortcut(srcTy, dstTy) &&
-         !isMfmaToDotShortcut(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
 }

 bool atomicNeedsSharedMemory(Value value) {
@@ -711,20 +732,6 @@ bool atomicNeedsSharedMemory(Value value) {
   return true;
 }

-bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
-    return true;
-  // dot_op<opIdx=0, parent=#mma> = #mma
-  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
-  auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
-         mmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 &&
-         dotOperandLayout.getParent() == mmaLayout &&
-         !srcTy.getElementType().isF32();
-}
-
 namespace {

 /// A data structure similar to SetVector but maintains
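To make the register-padding step in `minimalCvtLayout` above concrete: if, say, the source layout covers 4 `register` inputs and the destination covers 8, the injective forms built inside `invertAndCompose` would otherwise gain a different number of extra rows, so the rows standing for free variables would not line up. The sketch below restates that idiom as a self-contained helper; it only uses calls that appear in this diff (`getInDimSize`, `resize`, `invertAndCompose`), but the include paths, namespace, and helper name are assumptions, not part of the commit.

```cpp
#include <algorithm>
#include <cstdint>

#include "mlir/IR/BuiltinAttributes.h"
#include "triton/Tools/LinearLayout.h"

using mlir::StringAttr;
using mlir::triton::LinearLayout;

// Pad the "register" input dimension of both layouts to the same size before
// composing, mirroring the step added to minimalCvtLayout in this commit.
// Sketch only: assumes both layouts define a "register" input dimension.
LinearLayout composeWithPaddedRegs(const LinearLayout &srcLayout,
                                   const LinearLayout &dstLayout,
                                   StringAttr kRegister) {
  int32_t numRegs = std::max(srcLayout.getInDimSize(kRegister),
                             dstLayout.getInDimSize(kRegister));
  // E.g. 4 -> 8 adds one free register bit to the smaller layout; the larger
  // layout is returned unchanged.
  LinearLayout src = srcLayout.resize(kRegister, numRegs);
  LinearLayout dst = dstLayout.resize(kRegister, numRegs);
  // comp maps the (padded) src hardware index space onto dst's.
  return dst.invertAndCompose(src);
}
```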
