
Commit 75dcede

Merge OpenAI Triton commit 1d5fdfe (#2621)
This PR changes the Triton base from 78c8054 to 1d5fdfe (Oct 28). Pass rate: `99.83%`. Please do not squash and merge this PR.
2 parents 5b94131 + 0290b6c commit 75dcede

File tree

22 files changed: +252 −220 lines


.pre-commit-config.yaml

Lines changed: 6 additions & 8 deletions
@@ -1,6 +1,7 @@
+default_stages: [pre-commit, pre-push, manual]
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.4.0
+  rev: v5.0.0
   hooks:
   - id: check-symlinks
   - id: destroyed-symlinks
@@ -17,12 +18,11 @@ repos:
   - id: debug-statements

 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.1.3
+  rev: v0.7.1
   hooks:
   - id: ruff
     files: '^python/.*'
-    args: ["--fix", "--line-length", "120"]
-    stages: [pre-commit, pre-push, manual]
+    args: ["--fix", "--exit-non-zero-on-fix"]
     exclude: |
       (?x)(
         ^python/triton/runtime/.*|
@@ -31,18 +31,16 @@ repos:
       )

 - repo: https://github.com/google/yapf
-  rev: be72557
+  rev: "7e21823"
   hooks:
   - id: yapf
     args: ["-p", "-i"]
-    stages: [pre-commit, pre-push, manual]
     exclude: "python/test/unit/language/test_line_info.py"

 - repo: https://github.com/pre-commit/mirrors-clang-format
-  rev: v16.0.6
+  rev: v19.1.2
   hooks:
   - id: clang-format
-    stages: [pre-commit, pre-push, manual]

 # Expand YAML anchors in files used by github workflows, because github can't
 # do this itself. This lets us use anchors, which avoids code duplication.

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions
@@ -216,8 +216,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);

 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

-bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 3 additions & 11 deletions
@@ -18,14 +18,6 @@ namespace gpu {
 SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
                                  Type ouType);

-SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
-                             ConversionPatternRewriter &rewriter, Location loc,
-                             const LLVMTypeConverter *typeConverter);
-
-SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           const LLVMTypeConverter *typeConverter);
-
 Type getElementType(Value value);

 class MultipleOperandsRange
@@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      subOperands = unpackI32(subOperands, argTy, rewriter, loc,
-                              this->getTypeConverter());
+      subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
+                               this->getTypeConverter());
       allOperands.resize(subOperands.size());
       for (auto v : llvm::enumerate(subOperands))
         allOperands[v.index()].push_back(v.value());
@@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     }
     resultVals = maybeDeduplicate(op, resultVals);
     resultVals =
-        packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
+        packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
     Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
                                 rewriter, resultTy);
     rewriter.replaceOp(op, view);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 61 additions & 0 deletions
@@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc,
   return llvmStruct;
 }

+// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
+// instructions to pack & unpack sub-word integers. A workaround is to
+// store the results of tensors with dot operand encodings in i32 to
+// facilitate instructions such as `ldmatrix`.
+//
+// TODO: Confirm if the problem is still there.
+inline bool requiresI32Conversion(Type type) {
+  auto tensorTy = dyn_cast<RankedTensorType>(type);
+  if (!tensorTy)
+    return false;
+  auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
+  if (!dotOpEnc)
+    return false;
+  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
+  if (!(parent && parent.getVersionMajor() < 3))
+    return false;
+  return true;
+}
+
+inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
+                                   Type type, RewriterBase &rewriter,
+                                   Location loc,
+                                   const LLVMTypeConverter *typeConverter) {
+  if (!requiresI32Conversion(type))
+    return inValues;
+  Type eltTy =
+      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
+
+  SmallVector<Value> outValues;
+  int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
+  auto vecTy = vec_ty(eltTy, vecWidth);
+  for (int i = 0; i < inValues.size(); i += vecWidth) {
+    Value vec = undef(vecTy);
+    for (int j = 0; j < vecWidth; j++) {
+      vec = insert_element(vec, inValues[i + j], i32_val(j));
+    }
+    outValues.push_back(bitcast(vec, i32_ty));
+  }
+  return outValues;
+}
+
+inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
+                                     Type type, RewriterBase &rewriter,
+                                     Location loc,
+                                     const LLVMTypeConverter *typeConverter) {
+  if (!requiresI32Conversion(type))
+    return inValues;
+  Type eltTy =
+      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
+
+  SmallVector<Value> outValues;
+  for (auto v : inValues) {
+    auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
+    auto vec = bitcast(v, vecTy);
+    for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
+      outValues.push_back(extract_element(vec, i32_val(i)));
+    }
+  }
+  return outValues;
+}
+
 inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                            RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
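
The comment in the added helpers explains the motivation: dot-operand tensors whose parent MMA encoding is pre-Hopper are kept packed in i32 so the NVPTX backend does not emit extra sub-word pack/unpack instructions. As a rough standalone illustration of the bit layout that `packI32s`/`unpackI32s` round-trip (plain C++ with made-up 16-bit lane values, not the MLIR builder calls; a little-endian target is assumed, as on NVPTX, so lane 0 occupies the low half-word):

// Illustration only: pack pairs of 16-bit lanes into 32-bit words and unpack
// them again, mirroring the vecWidth = 32 / bitwidth grouping used above.
#include <cassert>
#include <cstdint>
#include <vector>

static uint32_t pack2x16(uint16_t lane0, uint16_t lane1) {
  // lane 0 in the low 16 bits, lane 1 in the high 16 bits
  return static_cast<uint32_t>(lane0) | (static_cast<uint32_t>(lane1) << 16);
}

static void unpack2x16(uint32_t word, uint16_t &lane0, uint16_t &lane1) {
  lane0 = static_cast<uint16_t>(word & 0xFFFF);
  lane1 = static_cast<uint16_t>(word >> 16);
}

int main() {
  // Four arbitrary f16 bit patterns; for 16-bit elements vecWidth is 32/16 = 2.
  std::vector<uint16_t> lanes = {0x3C00, 0xC000, 0x0001, 0x8000};
  std::vector<uint32_t> packed;
  for (size_t i = 0; i < lanes.size(); i += 2)
    packed.push_back(pack2x16(lanes[i], lanes[i + 1]));

  // Unpacking must recover the original lanes in order.
  std::vector<uint16_t> roundTrip;
  for (uint32_t w : packed) {
    uint16_t a, b;
    unpack2x16(w, a, b);
    roundTrip.push_back(a);
    roundTrip.push_back(b);
  }
  assert(roundTrip == lanes);
  return 0;
}

The `concat` lambda removed from ConvertLayoutOpToLLVM.cpp below builds what appears to be this same layout with `or_`/`zext`/`shl`, which is presumably why the generic helpers can take its place.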

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 18 deletions
@@ -700,15 +700,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }

 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
-  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
-  // subsumed by the linear-layout checks.
+  // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
+  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
+  // checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !isMmaToDotShortcut(srcTy, dstTy) &&
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
          !isMfmaToDotShortcut(srcTy, dstTy);
 }

@@ -719,20 +719,6 @@ bool atomicNeedsSharedMemory(Value value) {
   return true;
 }

-bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
-    return true;
-  // dot_op<opIdx=0, parent=#mma> = #mma
-  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
-  auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
-         mmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 &&
-         dotOperandLayout.getParent() == mmaLayout &&
-         !srcTy.getElementType().isF32();
-}
-
 namespace {

 /// A data structure similar to SetVector but maintains

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 22 additions & 21 deletions
@@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else {
       // Cast 5. The two layouts are equivalent. We should probably remove
       // these in RemoveLayoutConversion.
-      rewriter.replaceOp(op, adaptor.getSrc());
+      auto dstCvt = requiresI32Conversion(dstTy);
+      auto srcCvt = requiresI32Conversion(srcTy);
+      if (dstCvt || srcCvt) {
+        auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
+        inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
+                            getTypeConverter());
+        inVals =
+            packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
+        auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
+                                  rewriter, op.getType());
+        rewriter.replaceOp(op, res);
+      } else {
+        rewriter.replaceOp(op, adaptor.getSrc());
+      }
       return success();
     }
   }
@@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     StringAttr kRegister = str_attr("register");
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
     SmallVector<Value> outVals(numRegs);
-    for (int i = 0; i < outVals.size(); i++) {
+    for (int i = 0; i < numRegs; i++) {
       // Remove free masks from the register index
       // For example, if idx = 0b00111, and masks = 0b00100, then we get
       // 0b00011. It means that register 7 (0b111) has the same value as
@@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                         : idx;
       outVals[i] = inVals[srcIdx];
     }
+    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
@@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
       if (auto nvidiaMma =
               dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
-        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
-          return false;
-        }
         if (useLegacyMMAConversion) {
           return false;
         }
@@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
             dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
         return largeKWidth && nvidiaMma.isAmpere();
       }
+      return false;
     }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
@@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
       }
     }
+    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());

     // Pretty sure this is the identity function ATM
     // It'd be better to simply call `quotient({kBlock})` and
@@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }

-    // FIXME [Dot LL]
-    // We know it's just for largeKWidth case in Ampere
-    // In this case, we need to pack the outputs into i32
-    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
-      auto concat = [&](Value a, Value b) {
-        return or_(zext(i32_ty, bitcast(a, i16_ty)),
-                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
-      };
-
-      SmallVector<Value> outVals32(outVals.size() / 2);
-      for (int i = 0; i < outVals32.size(); ++i) {
-        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
-      }
-      outVals = outVals32;
-    }
-
+    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
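
In the register-remapping hunk above, the comment works a small example: free bits in the register index are cleared so that duplicate registers read from one canonical source (idx 0b00111 with mask 0b00100 maps to 0b00011). Below is a minimal sketch of that masking step with a hypothetical `freeMask` and register contents; the real code derives the mask and the guarding condition from the linear-layout conversion.

// Illustration only: canonicalize register indices by clearing "free" bits,
// so registers that differ only in those bits share one source value.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const uint32_t freeMask = 0b00100;                           // hypothetical free bit
  std::vector<int> inVals = {10, 11, 12, 13, 14, 15, 16, 17};  // 8 "registers"
  std::vector<int> outVals(inVals.size());

  for (uint32_t idx = 0; idx < outVals.size(); ++idx) {
    uint32_t srcIdx = idx & ~freeMask;  // e.g. 0b00111 -> 0b00011
    outVals[idx] = inVals[srcIdx];
  }

  // Register 7 (0b111) ends up with the same value as register 3 (0b011).
  assert(outVals[7] == inVals[3]);
  return 0;
}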

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 6 additions & 50 deletions
@@ -103,51 +103,6 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   llvm_unreachable("unimplemented code path");
 }

-SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
-                             ConversionPatternRewriter &rewriter, Location loc,
-                             const LLVMTypeConverter *typeConverter) {
-  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
-  if (!tensorTy)
-    return inValues;
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
-    return inValues;
-  SmallVector<Value> outValues;
-  for (auto v : inValues) {
-    // cast i32 to appropriate eltType vector and extract elements
-    auto eltType = typeConverter->convertType(tensorTy.getElementType());
-    auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
-    auto vec = bitcast(v, vecType);
-    for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
-      outValues.push_back(extract_element(vec, i32_val(i)));
-    }
-  }
-  return outValues;
-}
-
-SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           const LLVMTypeConverter *typeConverter) {
-  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
-  if (!tensorTy)
-    return inValues;
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
-    return inValues;
-  SmallVector<Value> outValues;
-  auto eltType = typeConverter->convertType(tensorTy.getElementType());
-  int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
-  auto vecType = vec_ty(eltType, vecWidth);
-  for (int i = 0; i < inValues.size(); i += vecWidth) {
-    Value vec = undef(vecType);
-    for (int j = 0; j < vecWidth; j++) {
-      vec = insert_element(vec, inValues[i + j], i32_val(j));
-    }
-    outValues.push_back(bitcast(vec, i32_ty));
-  }
-  return outValues;
-}
-
 int getNumElementsPerThreads(Type type,
                              const LLVMTypeConverter *typeConverter) {
   int numElemsPerThread = 1;
@@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
       unpackedOperands.push_back(
-          unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter()));
+          unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
     }

     int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
@@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion
             unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
             /*ouType=*/op->getResult(i).getType());
       }
-      auto packed = packI32(unpackedResults[i], op->getResult(i).getType(),
-                            rewriter, loc, getTypeConverter());
-      outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter,
-                                    op->getResult(i).getType()));
+      auto dstTy = op->getResult(i).getType();
+      unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
+                                    getTypeConverter());
+      outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
+                                    rewriter, op->getResult(i).getType()));
     }

     rewriter.replaceOp(op, outs);
