Commit 876ce90

Merge commit '1cf7b1b31cde8c62611e421becd4648c7284d76c'
2 parents: e30e00f + 1cf7b1b

20 files changed: +289, -544 lines

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 0 additions & 11 deletions
@@ -15,9 +15,6 @@ namespace mlir::triton {
 
 namespace gpu {
 
-SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
-                                 Type ouType);
-
 Type getElementType(Value value);
 
 class MultipleOperandsRange
@@ -179,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
-                               this->getTypeConverter());
       allOperands.resize(subOperands.size());
       for (auto v : llvm::enumerate(subOperands))
         allOperands[v.index()].push_back(v.value());
@@ -201,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       }
       it += curr.size();
     }
-    if (op->getNumOperands() > 0) {
-      auto argTy = op->getOperand(0).getType();
-      resultVals = reorderValues(resultVals, argTy, resultTy);
-    }
     resultVals = maybeDeduplicate(op, resultVals);
-    resultVals =
-        packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
     Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
                                 rewriter, resultTy);
     rewriter.replaceOp(op, view);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 8 additions & 65 deletions
@@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
 // MXFP utilities
 // -----------------------------------------------------------------------
 
-// Convert one int8, which contain, 2 packed mxfp4 values, into 2 bf16
-// standalone values and returns them as a pair for (high 4 bits, low 4 bits).
-std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
-                                               Location loc, Value v);
+// Convert each value, which is an int8 containing 2 packed mxfp4 values,
+// into 2 standalone bf16 values
+SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values);
+
+// Scale a mxfp4 value by a given scale.
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);
+
 } // namespace LLVM
 
 /* ------------------------------------ */
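
Note on the new API above: the helper now decodes a whole list of packed bytes at once, and the separate mxfpScaleBf16 entry point applies the per-block scale afterwards. As a rough, self-contained sketch of what decoding one such byte involves, assuming the usual e2m1 nibble layout (1 sign, 2 exponent, 1 mantissa bit) and low-nibble-first packing; this is an illustration, not the Triton implementation:

// Illustration only: decode one byte holding two packed mxfp4 (e2m1) values.
#include <cstdint>
#include <cstdio>
#include <utility>

static float decodeE2M1(uint8_t nibble) {
  int sign = (nibble >> 3) & 0x1;
  int exp = (nibble >> 1) & 0x3;
  int man = nibble & 0x1;
  // exp == 0 encodes the subnormals 0 and 0.5; otherwise the exponent bias is 1.
  float mag = exp == 0 ? 0.5f * man : (1.0f + 0.5f * man) * float(1 << (exp - 1));
  return sign ? -mag : mag;
}

static std::pair<float, float> decodeMxfp4x2(uint8_t v) {
  return {decodeE2M1(v & 0xF), decodeE2M1(v >> 4)}; // (low nibble, high nibble)
}

int main() {
  auto [lo, hi] = decodeMxfp4x2(0x7E); // 0xE -> -4.0, 0x7 -> 6.0
  std::printf("%g %g\n", lo, hi);      // prints: -4 6
}
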
@@ -1397,67 +1401,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
   return llvmStruct;
 }
 
-// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
-// instructions to pack & unpack sub-word integers. A workaround is to
-// store the results of tensors with dot operand encodings in i32 to
-// facilitate instructions such as `ldmatrix`.
-//
-// TODO: Confirm if the problem is still there.
-inline bool requiresI32Conversion(Type type) {
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy)
-    return false;
-  auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!dotOpEnc)
-    return false;
-  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
-  if (!(parent && parent.getVersionMajor() < 3))
-    return false;
-  return true;
-}
-
-inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
-                                   Type type, RewriterBase &rewriter,
-                                   Location loc,
-                                   const LLVMTypeConverter *typeConverter) {
-  if (!requiresI32Conversion(type))
-    return inValues;
-  Type eltTy =
-      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
-
-  SmallVector<Value> outValues;
-  int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
-  auto vecTy = vec_ty(eltTy, vecWidth);
-  for (int i = 0; i < inValues.size(); i += vecWidth) {
-    Value vec = undef(vecTy);
-    for (int j = 0; j < vecWidth; j++) {
-      vec = insert_element(vec, inValues[i + j], i32_val(j));
-    }
-    outValues.push_back(bitcast(vec, i32_ty));
-  }
-  return outValues;
-}
-
-inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
-                                     Type type, RewriterBase &rewriter,
-                                     Location loc,
-                                     const LLVMTypeConverter *typeConverter) {
-  if (!requiresI32Conversion(type))
-    return inValues;
-  Type eltTy =
-      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
-
-  SmallVector<Value> outValues;
-  for (auto v : inValues) {
-    auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
-    auto vec = bitcast(v, vecTy);
-    for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
-      outValues.push_back(extract_element(vec, i32_val(i)));
-    }
-  }
-  return outValues;
-}
-
 inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                            RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
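
The packI32s/unpackI32s helpers deleted above implemented the workaround described in the removed comment: sub-word elements of dot-operand tensors were grouped so that each thread's values travel in full 32-bit registers. A self-contained sketch of the same round trip on plain integers (illustration only; the real helpers emitted LLVM vector insert/extract and bitcast ops):

// Illustration only: pack 16-bit lane values into 32-bit words and unpack them
// again, mirroring what the removed packI32s/unpackI32s did for LLVM IR values.
#include <cstdint>
#include <vector>

std::vector<uint32_t> packI32(const std::vector<uint16_t> &in) {
  std::vector<uint32_t> out;
  for (size_t i = 0; i + 1 < in.size(); i += 2) // two 16-bit values per word
    out.push_back(uint32_t(in[i]) | (uint32_t(in[i + 1]) << 16));
  return out;
}

std::vector<uint16_t> unpackI32(const std::vector<uint32_t> &in) {
  std::vector<uint16_t> out;
  for (uint32_t v : in) {
    out.push_back(uint16_t(v & 0xFFFF)); // low half first, matching packI32
    out.push_back(uint16_t(v >> 16));
  }
  return out;
}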

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 4 deletions
@@ -1212,8 +1212,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     bool isAmpere() const;
     bool isHopper() const;
 
-    unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
-
     // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
     std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
 
@@ -1230,8 +1228,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<int> getMMAv1Rep(int opIdx) const;
     SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
     int getMMAv1Vec(int opIdx) const;
-    SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
-                                                   int bitwidth, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
+                                          int bitwidth, int opIdx) const;
 
     bool supportReduction() const {
       if (isAmpere() || isHopper()) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 8 deletions
@@ -357,7 +357,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                         : idx;
       outVals[i] = inVals[srcIdx];
     }
-    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
@@ -392,11 +391,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       if (useLegacyMMAConversion) {
        return false;
       }
-      // FIXME [Dot LL]
-      // Enabling LL path for buggy kWidth path
-      bool largeKWidth =
-          dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
-      return largeKWidth && nvidiaMma.isAmpere();
+      if (nvidiaMma.isAmpere()) {
+        return true;
+      }
     }
     return false;
   }
@@ -440,7 +437,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
        inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
      }
     }
-    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
 
     // Pretty sure this is the identity function ATM
     // It'd be better to simply call `quotient({kBlock})` and
@@ -460,7 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
-    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 3 additions & 9 deletions
@@ -90,16 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
-      // FIXME [Dot LL]
-      // We support this one via LLs, as the LocalLoad path is buggy
-      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
-        bool largeKWidth =
-            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-        if (mma.isAmpere() && largeKWidth) {
-          return;
-        }
+      auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
+      if (dotParent && dotParent.isAmpere()) {
+        return;
       }
-
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 7 additions & 133 deletions
@@ -11,138 +11,23 @@ using namespace mlir::triton::gpu;
 
 namespace mlir::triton::gpu {
 
-namespace {
-
-bool isDotOpTensorAndPacked(Type srcTy) {
-  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
-  if (!tensorTy)
-    return false;
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!encoding)
-    return false;
-  auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
-  // By code convention, values for Hopper's dotOp-encoded tensors are not
-  // packed
-  if (!parentEnc || parentEnc.isHopper())
-    return false;
-  return true;
-}
-
-} // namespace
-
 Type getElementType(Value value) {
   auto type = value.getType();
   if (auto tensorType = dyn_cast<RankedTensorType>(type))
     return tensorType.getElementType();
   return type;
 }
-// MMA encoding has a different order depending on the element's bit width;
-// reorder if we're in this case.
-SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
-                                 Type ouType) {
-  auto inTensorTy = dyn_cast<RankedTensorType>(inType);
-  auto ouTensorTy = dyn_cast<RankedTensorType>(ouType);
-  if (!inTensorTy || !ouTensorTy)
-    return values;
-  auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
-  auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
-  assert(inEncoding == ouEncoding);
-  if (!inEncoding)
-    return values;
-  // If the parent of the dot operand is in block encoding, we don't need to
-  // reorder elements
-  auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
-  if (!parentEncoding || parentEncoding.isHopper())
-    return values;
-  size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
-  size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
-  auto ouEltTy = ouTensorTy.getElementType();
-  if (inBitWidth == ouBitWidth)
-    return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
-    // Register layout conversion:
-    //
-    //   [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
-    //   [2, 3], [6, 7]    [2], [3], [6], [7]
-    //
-    // Original access order:
-    //
-    //   [0, 1], [2, 3], [4, 5], [6, 7]
-    //
-    // Transformed access order:
-    //
-    //   [0], [2], [1], [3], [4], [6], [5], [7]
-    SmallVector<Value> ret;
-    for (unsigned i = 0; i < values.size(); i += 8) {
-      ret.push_back(values[i]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 6]);
-      ret.push_back(values[i + 5]);
-      ret.push_back(values[i + 7]);
-    }
-    return ret;
-  }
-  if (inBitWidth == 8 && ouBitWidth == 16) {
-    // Register layout conversion:
-    //
-    //   [0, 1, 2, 3], [8, 9, 10, 11]  ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
-    //   [4, 5, 6, 7], [12, 13, 14, 15]   [4, 5], [6, 7], [12, 13], [14, 15]
-    //
-    // Original access order:
-    //
-    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
-    //
-    // Transformed access order:
-    //
-    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
-    SmallVector<Value> ret;
-    for (unsigned i = 0; i < values.size(); i += 16) {
-      ret.push_back(values[i]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 5]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 6]);
-      ret.push_back(values[i + 7]);
-      ret.push_back(values[i + 8]);
-      ret.push_back(values[i + 9]);
-      ret.push_back(values[i + 12]);
-      ret.push_back(values[i + 13]);
-      ret.push_back(values[i + 10]);
-      ret.push_back(values[i + 11]);
-      ret.push_back(values[i + 14]);
-      ret.push_back(values[i + 15]);
-    }
-    return ret;
-  }
-  llvm_unreachable("unimplemented code path");
-}
 
 int getNumElementsPerThreads(Type type,
                              const LLVMTypeConverter *typeConverter) {
   int numElemsPerThread = 1;
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy)
-    return numElemsPerThread;
-  auto structType =
-      dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
-  if (structType) {
-    numElemsPerThread = structType.getBody().size();
+  if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
+    auto structType =
+        dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
+    if (structType)
+      numElemsPerThread = structType.getBody().size();
   }
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
-    return numElemsPerThread;
-  auto eltType = tensorTy.getElementType();
-  assert(eltType.getIntOrFloatBitWidth() <= 32 &&
-         "Only support element type with bit width <= 32 in dot operand mma "
-         "layout");
-  // dot operand data are packed into i32 elements so use the following formula
-  // to get the number of elements per thread.
-  return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread;
+  return numElemsPerThread;
 }
 
 } // namespace mlir::triton::gpu
@@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      unpackedOperands.push_back(
-          unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
+      unpackedOperands.push_back(subOperands);
     }
 
     int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
@@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion
     // Reorder and pack the results.
     SmallVector<Value> outs;
     for (int i = 0; i < unpackedResults.size(); i++) {
-      // We reordered all the inputs so they match operand 0. Reorder the
-      // outputs accordingly.
-      if (op->getNumOperands() > 0) {
-        unpackedResults[i] = reorderValues(
-            unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
-            /*ouType=*/op->getResult(i).getType());
-      }
-      auto dstTy = op->getResult(i).getType();
-      unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
-                                    getTypeConverter());
       outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
                                     rewriter, op->getResult(i).getType()));
     }
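
The deleted reorderValues helper applied a fixed per-group permutation so that values produced under one dot-operand register layout could be consumed under another (see the access-order comments in the removed code above). A standalone sketch of the 16-bit to 32-bit case, which rewrites each group of eight values in the order [0, 2, 1, 3, 4, 6, 5, 7] (illustration only, not the removed MLIR code):

#include <cstddef>
#include <vector>

// Illustration only: the per-group-of-8 permutation that the removed
// 16-bit -> 32-bit branch of reorderValues applied to the flat value list.
template <typename T>
std::vector<T> reorder16To32(const std::vector<T> &values) {
  static const int perm[8] = {0, 2, 1, 3, 4, 6, 5, 7};
  std::vector<T> ret;
  ret.reserve(values.size());
  for (size_t i = 0; i + 8 <= values.size(); i += 8)
    for (int j : perm)
      ret.push_back(values[i + j]);
  return ret;
}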

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 0 additions & 2 deletions
@@ -173,7 +173,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstLayout = dstTy.getEncoding();
     assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
-    auto inOrd = getOrder(srcSharedLayout);
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
@@ -183,7 +182,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     SmallVector<Value> outVals = loadSharedToDistributed(
         dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
 
-    outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);
 
