
Commit c96c236

Merge branch 'main' into etiotto.remove_rewrite_tensor_ptr

2 parents: 2d22907 + 6536edb

File tree: 33 files changed, +580 −403 lines

.github/pins/pytorch-upstream.txt (1 addition, 1 deletion)

@@ -1 +1 @@
-0efa590d435d2b4aefcbad9014dd5fa75dcf8405
+33dce10ece5b38aa0ab76739b658cd980a6e3d8f

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h (0 additions, 7 deletions)

@@ -246,13 +246,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
-
-// FIXME
-// Exposing to use it in LinearLayoutConversionsTest.cpp
-// Remove it once we fully activate the DotOperand conversion via LLs
-class DotOperandEncodingAttr;
-LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
-                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (34 additions, 23 deletions)

@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
 
     InterfaceMethod<"Return shape per CTA.",
                     "SmallVector<unsigned>",
-                    "getShapePerCTATileForDotOperands",
+                    "getShapePerCTATileForOperand",
                     (ins "ArrayRef<int64_t>":$tensorShape,
-                         "unsigned":$opIdx)>,
+                         "int":$kWidth,
+                         "int":$opIdx)>,
 
     InterfaceMethod<"Return total element size per thread for dot operands.",
                     "unsigned",
-                    "getTotalElemsPerThreadForOperands",
+                    "getTotalElemsPerThreadForOperand",
                     (ins "ArrayRef<int64_t>":$tensorShape,
                          "Type":$eltTy,
-                         "unsigned":$kWidth,
-                         "unsigned":$opIdx)>,
+                         "int":$kWidth,
+                         "int":$opIdx)>,
 
     InterfaceMethod<"Return size per thread for dot operands.",
                     "SmallVector<unsigned>",
-                    "getSizePerThreadForOperands",
-                    (ins "unsigned":$opIdx)>,
+                    "getSizePerThreadForOperand",
+                    (ins "int":$opIdx,
+                         "int":$kWidth)>,
 
     InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
                     "getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,

@@ -919,11 +921,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     bool supportReduction() const {
       return true;
     }
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
-    SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
-    SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       auto rank = getWarpsPerCTA().size();

@@ -1030,12 +1032,12 @@ Row | warp 0 warp 2
     bool supportReduction() const {
       return true;
     }
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
-    SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
-                                           Type elemType, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
+                                          Type elemType, int kWidth, int opIdx) const;
     static SmallVector<unsigned> getMNKDimPerInstr();
 
     SmallVector<unsigned> getContigPerThread() {

@@ -1235,18 +1237,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<int> getMMAv1Rep(int opIdx) const;
     SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
     int getMMAv1Vec(int opIdx) const;
-    SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
-                                     int bitwidth, int opIdx) const;
+    SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
+                                               int bitwidth, int kWidth, int opIdx) const;
 
     bool supportReduction() const {
       if (isAmpere() || isHopper()) {
         return true;
       }
       return false;
     };
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       assert(isVolta() || isAmpere() || isHopper());

@@ -1361,7 +1363,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
   let genVerifyDecl = 1;
   let extraClassDeclaration = extraDistributedDeclaration # [{
     SmallVector<unsigned> getContigPerThread() {
-      return getSizePerThread();
+      auto rank = getWarpsPerCTA().size();
+      assert(rank == 2 || rank == 3);
+      SmallVector<unsigned> contigPerThread(rank, 1);
+      auto kWidth = getKWidth();
+      assert(kWidth != 0 && "Do not support kWidth=0");
+      if (getOpIdx() == 0)
+        contigPerThread[rank - 1] = kWidth;
+      else
+        contigPerThread[rank - 2] = kWidth;
+      return contigPerThread;
     };
   }];
 }
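For orientation (not part of the commit): the new DotOperandEncodingAttr::getContigPerThread() above makes per-thread contiguity follow kWidth along the operand's K dimension. A minimal standalone C++ sketch of the same computation; contigPerThreadSketch is a hypothetical name for illustration, not the actual MLIR method:

#include <cassert>
#include <vector>

// Mirrors the logic added in the diff: kWidth contiguous elements along K
// (the last dim for operand A, opIdx == 0; the second-to-last for operand B,
// opIdx == 1), and 1 in every other dimension.
std::vector<unsigned> contigPerThreadSketch(unsigned rank, unsigned kWidth,
                                            int opIdx) {
  assert((rank == 2 || rank == 3) && kWidth != 0);
  std::vector<unsigned> contig(rank, 1);
  if (opIdx == 0)
    contig[rank - 1] = kWidth; // operand A: K is the last dimension
  else
    contig[rank - 2] = kWidth; // operand B: K is the second-to-last dimension
  return contig;
}

// Examples: rank=2, kWidth=8, opIdx=0 -> {1, 8}
//           rank=2, kWidth=8, opIdx=1 -> {8, 1}
//           rank=3, kWidth=4, opIdx=0 -> {1, 1, 4}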

lib/Analysis/Allocation.cpp (6 additions, 1 deletion)

@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   assert(!isMfmaToDotShortcut(srcTy, dstTy));
 
-  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp (64 additions, 21 deletions)

@@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return rewriter.notifyMatchFailure(
           op, "NYI. srcTy and/or dstTy don't implement LLs yet");
     }
+    LinearLayout srcLayout =
+        *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    LinearLayout dstLayout =
+        *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+
+    StringAttr kBlock = str_attr("block");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kRegister = str_attr("register");
 
     assert(to_vector(conversion->getInDimNames()) ==
            to_vector(conversion->getOutDimNames()));
     auto dims = conversion->getInDimNames();
-    if (llvm::is_contained(dims, str_attr("block"))) {
+    if (llvm::is_contained(dims, kBlock)) {
       // Case 1: Transfer between values in different CTAs.
       // This requires moving values through distributed shared memory.
       return rewriter.notifyMatchFailure(
           op, "NYI: Transfer between different CTAs");
-    } else if (llvm::is_contained(dims, str_attr("warp"))) {
+    } else if (llvm::is_contained(dims, kWarp)) {
       // Case 2: Transfer between values in the same CTA, in which case we move
       // values through shared memory.
-      LinearLayout srcLayout =
-          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-      LinearLayout dstLayout =
-          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, str_attr("lane"))) {
+    } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try
       // to move values using warp shuffles, though if the pattern is
       // complicated enough we may fall back to using shared memory
       // TODO(Keren): implement warp shuffle instead of using the general
       // approach that uses shared memory
-      LinearLayout srcLayout =
-          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-      LinearLayout dstLayout =
-          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, str_attr("register"))) {
+    } else if (llvm::is_contained(dims, kRegister) ||
+               dstLayout.getInDimSize(kRegister) !=
+                   srcLayout.getInDimSize(kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
       // simply reorder the elements of adaptor.getSrc().
-      return transferWithinThread(op, *conversion, adaptor, rewriter);
+      return transferWithinThread(
+          op, dstLayout.getFreeVariableMasks()[kRegister],
+          dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
     } else {
-      // The two layouts are equivalent. We should probably remove these in
-      // RemoveLayoutConversion.
+      // Case 5. The two layouts are equivalent. We should probably remove
+      // these in RemoveLayoutConversion.
      rewriter.replaceOp(op, adaptor.getSrc());
      return success();
    }
  }
 
   LogicalResult
-  transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
-                       OpAdaptor adaptor,
+  transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
+                       const LinearLayout &conversion, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    SmallVector<Value> outVals;
-    outVals.resize(conversion.getInDimSize(kRegister));
-    for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
-      auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
+    SmallVector<Value> outVals(numRegs);
+    for (int i = 0; i < outVals.size(); i++) {
+      // Remove free masks from the register index
+      // For example, if idx = 0b00111, and masks = 0b00100, then we get
+      // 0b00011. It means that register 7 (0b111) has the same value as
+      // register 3 (0b011).
+      auto idx = i & (~regMasks);
+      auto srcIdx = conversion.hasInDim(kRegister)
+                        ? conversion.apply({{kRegister, idx}}).begin()->second
+                        : idx;
       outVals[i] = inVals[srcIdx];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
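To make the free-variable masking above concrete, here is a small self-contained sketch (assuming plain 32-bit indices, as in the diff's comment; an illustration, not the Triton implementation):

#include <cstdint>
#include <cstdio>

int main() {
  // Bit 2 is a "free variable": register indices that differ only in that
  // bit hold the same value, so clearing it yields the canonical source.
  int32_t regMasks = 0b00100;
  for (int32_t i = 0; i < 8; ++i) {
    int32_t idx = i & ~regMasks;
    std::printf("register %d reads from register %d\n", i, idx);
  }
  // Prints "register 7 reads from register 3", matching the
  // 0b00111 & ~0b00100 == 0b00011 example in the comment above.
  return 0;
}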
@@ -372,6 +383,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return true;
     }
+    if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto nvidiaMma =
+              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
+        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
+          return false;
+        }
+        if (useLegacyMMAConversion) {
+          return false;
+        }
+        // FIXME [Dot LL]
+        // Enabling LL path for buggy kWidth path
+        bool largeKWidth =
+            dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
+        return largeKWidth && nvidiaMma.isAmpere();
+      }
+    }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }
@@ -431,6 +458,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
+    // FIXME [Dot LL]
+    // We know it's just for largeKWidth case in Ampere
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      auto concat = [&](Value a, Value b) {
+        return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+      };
+
+      SmallVector<Value> outVals32(outVals.size() / 2);
+      for (int i = 0; i < outVals32.size(); ++i) {
+        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+      }
+      outVals = outVals32;
+    }
+
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
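The concat lambda above builds an i32 from two 16-bit lanes, first value in the low half. A plain C++ analogue of that bit manipulation, with ordinary integer ops standing in for the LLVM builder calls (bitcast/zext/shl/or_):

#include <cstdint>

// Pack 'a' into bits 0-15 and 'b' into bits 16-31 of one 32-bit word.
uint32_t concat16(uint16_t a, uint16_t b) {
  return static_cast<uint32_t>(a) | (static_cast<uint32_t>(b) << 16);
}

// e.g. concat16(0x1111, 0x2222) == 0x22221111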

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp (10 additions, 0 deletions)

@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
+        bool largeKWidth =
+            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+        if (mma.isAmpere() && largeKWidth) {
+          return;
+        }
+      }
+
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
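For reference, the largeKWidth predicate used here (and in the ConvertLayoutOpToLLVM.cpp hunk above) is a simple threshold on one thread's contiguous K-slice; a minimal sketch:

// True when one thread's contiguous K-slice exceeds 64 bits.
// e.g. kWidth=8 with 16-bit elements: 8 * 16 = 128 > 64 -> take the LL path;
//      kWidth=4 with 16-bit elements: 4 * 16 = 64       -> decompose as before.
bool isLargeKWidth(int kWidth, int elemBitWidth) {
  return kWidth * elemBitWidth > 64;
}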
