Commit e2dc77b

[LAYOUTS] Use LLs for Hopper whenever we wouldn't use ldmatrix (#5235)
The legacy path has some bugs for cases like `kWidth=1`. I'm starting to port Hopper to use LLs (linear layouts) to try to isolate them.
Parent: deee78f

15 files changed: +164, −124 lines

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -26,6 +26,9 @@ python/triton/language/extra
 # Proton
 python/triton/profiler
 
+# Pytest
+pytest.ini
+
 # Instrumentation
 python/triton/instrumentation
 
```

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 7 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 #include "triton/Conversion/MLIRTypes.h"
 
 namespace mlir::triton {
+
 class TargetInfoBase {
 public:
   virtual bool supportMaximumMinimum() const = 0;
@@ -37,6 +38,12 @@ class TargetInfoBase {
                    pred);
   }
 
+  virtual bool canUseStMatrix(RankedTensorType tensorTy,
+                              ArrayRef<unsigned> repShape,
+                              ArrayRef<unsigned> paddedRepShape,
+                              ArrayRef<unsigned> order,
+                              int swizzleByteSize) const = 0;
+
   virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
                                  Value ptr, Value val) const = 0;
 
```

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 5 deletions

```diff
@@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // TODO(Keren): We should replace tensorTy with a LinearLayout and the element
 // bit width of the tensor in the future to support more flexible tensor
 // encodings
-std::optional<LinearLayout>
-chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                     ArrayRef<unsigned> repShape,
-                     ArrayRef<unsigned> paddedRepShape,
-                     ArrayRef<unsigned> order, int swizzleByteSize);
+LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                                  ArrayRef<unsigned> repShape,
+                                  ArrayRef<unsigned> paddedRepShape,
+                                  ArrayRef<unsigned> order,
+                                  int swizzleByteSize);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
```

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 2 deletions

```diff
@@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
       int k = (needTrans) ? matShape[0] : matShape[2];
       int vec = (order[0] == rank-1) ? k : m;
       int mmaStride = (order[0] == rank-1) ? m : k;
-      int maxPhase = mmaStride / perPhase;
+      int maxPhase = std::max(mmaStride / perPhase, 1);
       return get(context, vec, perPhase, maxPhase, order, CTALayout);
     }
 
@@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
       int k = needTrans ? matShape[1] : matShape[2];
       int vec = (order[0] == rank-1) ? n : k;
       int mmaStride = (order[0] == rank-1) ? k : n;
-      int maxPhase = mmaStride / perPhase;
+      int maxPhase = std::max(mmaStride / perPhase, 1);
       return get(context, vec, perPhase, maxPhase, order, CTALayout);
     }
 
```
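
Why the clamp matters: `maxPhase` is computed with integer division, so whenever `perPhase` exceeds `mmaStride` the old code produced `maxPhase == 0`, which describes a degenerate swizzle (phase arithmetic typically reduces modulo `maxPhase`). A minimal sketch of the arithmetic, with illustrative values rather than ones taken from this commit:

```python
# Minimal sketch of the maxPhase fix above; the names mirror
# TritonGPUAttrDefs.td, but the sample values are illustrative only.
def max_phase(mma_stride: int, per_phase: int, clamped: bool) -> int:
    phase = mma_stride // per_phase
    return max(phase, 1) if clamped else phase

assert max_phase(mma_stride=2, per_phase=4, clamped=False) == 0  # old: degenerate swizzle
assert max_phase(mma_stride=2, per_phase=4, clamped=True) == 1   # fixed: at least one phase
assert max_phase(mma_stride=8, per_phase=2, clamped=True) == 4   # unchanged when stride >= perPhase
```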

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -628,7 +628,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
              dotOperandLayout.getOpIdx() == 0 &&
              mmaLayout.getWarpsPerCTA()[1] == 1 &&
              !cvtNeedsSharedMemory(parentTy, srcTy) &&
-             (elementTypeSize == 16 || elementTypeSize == 8);
+             (elementTypeSize == 16 || elementTypeSize == 8) &&
+             dotOperandLayout.getKWidth() == 32 / elementTypeSize;
   return ans;
 }
 
```
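
The extra conjunct restricts the mmaV3-to-dot-operand shortcut to the `kWidth` where the two layouts line up register for register. A sketch of the new condition, with sample widths that are illustrative only:

```python
# Sketch of the conjunct added to matchMmaV3AndDotOperandLayout: the
# shortcut now also requires that kWidth pack exactly one 32-bit register
# along K, i.e. kWidth * elementTypeSize == 32.
def k_width_matches(element_type_size: int, k_width: int) -> bool:
    return k_width == 32 // element_type_size

assert k_width_matches(16, 2)      # 16-bit elements: kWidth 2 still qualifies
assert k_width_matches(8, 4)       # 8-bit elements: kWidth 4 still qualifies
assert not k_width_matches(16, 1)  # kWidth=1, the buggy case from the commit
                                   # message, no longer takes the shortcut
```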

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 26 additions & 35 deletions

```diff
@@ -376,28 +376,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // completed before we can remove the layoutIsOK check:
     // 1. Support for AMD's WMMA
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
-        return !useLegacyMMAConversion;
-      }
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        auto parent = dotOperand.getParent();
-        if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
-          return false;
-        }
-        if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-          if (nvidiaMma.isAmpere()) {
-            return true;
-          }
-        }
-        if (isa<AMDMfmaEncodingAttr>(parent)) {
-          return true;
-        }
-        return false;
+        layout = dotOperand.getParent();
       }
-      if (isa<BlockedEncodingAttr>(layout)) {
-        return true;
+
+      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+        return !useLegacyMMAConversion;
       }
-      if (isa<LinearEncodingAttr>(layout)) {
+      if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
         return true;
       }
       if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -408,6 +394,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
       return failure();
    }
+    // FIXME [Dot LL] Remove this once we implement this trick in LLs
+    if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
+      return failure();
+    }
 
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
@@ -498,34 +488,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // don't need to avoid duplicate writes.
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
-    std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
-                             scratchConfig.paddedRepShape, scratchConfig.order,
-                             /*swizzleByteSize=*/0);
-    bool isStMatrix = shmemStoreLayout.has_value();
-    if (!isStMatrix) {
-      shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
-    }
-    assert(shmemStoreLayout.has_value());
+    bool isStMatrix = targetInfo.canUseStMatrix(
+        op.getSrc().getType(), scratchConfig.repShape,
+        scratchConfig.paddedRepShape, scratchConfig.order,
+        /*swizzleByteSize=*/0);
+    LinearLayout shmemStoreLayout =
+        isStMatrix ? chooseStMatrixLayout(
+                         ctx, op.getSrc().getType(), scratchConfig.repShape,
+                         scratchConfig.paddedRepShape, scratchConfig.order,
+                         /*swizzleByteSize=*/0)
+                   : srcLayout.invertAndCompose(sharedLayout);
 
     const int shmemAllocatedNumElems =
         getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
+    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
 
     // Layout for the load from shmem to registers.
     LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
 
     // Check that the `register` fully determines the `iteration`. That is,
     // each thread does exactly the same reads and writes to shmem on each
     // iteration, just with different input/output registers.
-    assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
-                                             {kIteration}));
+    assert(
+        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
     assert(
         shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
 
     // iteration -> registers
     SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, *shmemStoreLayout);
+        collectRegsForIter(ctx, shmemStoreLayout);
     SmallVector<SmallVector<int>> outRegsForIter =
         collectRegsForIter(ctx, shmemLoadLayout);
 
@@ -582,7 +573,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return vecAddr;
     };
 
-    auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
+    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
                                        {{kRegister, i32_val(0)},
                                         {kLane, laneId},
                                         {kWarp, warpId},
@@ -605,11 +596,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     // When using `stmatrix`, we can store `inVec` elements even if they are
    // not contiguous
-    auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
+    auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
                             : scratchConfig.inVec;
     for (int j = 0; j < inVals.size() / iterations; j += inVec) {
       auto inRegSlice = inRegs[j];
-      Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
+      Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
       SmallVector<Value> inValsVec;
       for (int k = 0; k < inVec; k++)
         inValsVec.push_back(inVals[inRegSlice + k]);
```
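
The net effect of the `layoutIsOK` rewrite is that a dot-operand encoding is now accepted or rejected purely by its parent layout, which is what lets Hopper (mmaV3) dot operands reach the LL path at all. A rough Python mimic of the simplified predicate; the classes below are stand-ins for the MLIR encoding attributes, not real Triton APIs:

```python
# Rough mimic of the simplified layoutIsOK predicate above.
from dataclasses import dataclass


class BlockedLayout: pass
class LinearLayout: pass
class NvidiaMmaLayout: pass
class AmdMfmaLayout: pass


@dataclass
class DotOperandLayout:
    parent: object


@dataclass
class SliceLayout:
    parent: object


def layout_is_ok(layout, use_legacy_mma_conversion: bool = False) -> bool:
    # A dot operand is now judged purely by its parent layout, which removes
    # the old Ampere-only special case.
    if isinstance(layout, DotOperandLayout):
        layout = layout.parent
    if isinstance(layout, (NvidiaMmaLayout, AmdMfmaLayout)):
        return not use_legacy_mma_conversion
    if isinstance(layout, (BlockedLayout, LinearLayout)):
        return True
    if isinstance(layout, SliceLayout):
        return layout_is_ok(layout.parent, use_legacy_mma_conversion)
    return False


# A Hopper (mmaV3) dot operand now passes, where the old code returned false.
assert layout_is_ok(DotOperandLayout(parent=NvidiaMmaLayout()))
```

The `matchMmaV3AndDotOperandLayout` early-out added above is the one remaining carve-out: that case can skip shared memory entirely, a trick the LL path does not implement yet, so it still falls back to the legacy lowering (the FIXME in the hunk).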

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 29 additions & 15 deletions

```diff
@@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
 
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(RankedTensorType type) {
-    auto layout = type.getEncoding();
-    auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+  static bool isSupportedDotOpLayout(MemDescType srcTy,
+                                     RankedTensorType dstTy) {
+    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
+    auto dstLayout = dstTy.getEncoding();
+    auto bitwidth = dstTy.getElementTypeBitWidth();
+    auto rank = dstTy.getRank();
+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
+      auto vecWidth = 32 / bitwidth;
       auto kWidth = dot.getKWidth();
-      // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
-      // - kWidth == 8
-      // - kWidth == 4, bitwidth = 32
+      auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+        auto needTrans = kOrder != srcLayout.getOrder()[0];
+        auto canUseLdmatrix =
+            (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
+        if (mma.isHopper()) {
+          // I think we should be able to remove this condition, but it's here
+          // as the legacy ldmatrix path does not support it
+          canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
+        }
+        // If we remove this one, ldmatrix will IMA. It can probably be relaxed
+        // though
+        canUseLdmatrix &=
+            srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
+        // To be removed in https://github.com/triton-lang/triton/pull/5154
         bool legacyLoweringIsBuggy =
-            kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
-        return legacyLoweringIsBuggy && mma.isAmpere();
+            (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+        return (mma.isHopper() && !canUseLdmatrix) ||
+               (mma.isAmpere() && legacyLoweringIsBuggy);
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
         return true;
@@ -162,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
-    Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isa<SharedEncodingAttr>(srcLayout) &&
-        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-             LinearEncodingAttr>(dstLayout) ||
-         isSupportedDotOpLayout(dstTy))) {
+    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
+            LinearEncodingAttr>(dstLayout) ||
+        isSupportedDotOpLayout(srcTy, dstTy)) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -206,7 +220,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
```
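
This is where the commit title lands: `isSupportedDotOpLayout` now returns true, i.e. routes the local load through the LL lowering, on Hopper exactly when `ldmatrix` would *not* be used, and on Ampere when the legacy lowering is known to be buggy. A standalone sketch of the `canUseLdmatrix` arithmetic for a rank-2 `opIdx=0` operand; the sample shapes and widths are illustrative, not from the source:

```python
# Standalone sketch of the canUseLdmatrix test above. Names follow the C++.
def can_use_ldmatrix(bitwidth: int, k_width: int, need_trans: bool,
                     src_bitwidth: int, src_shape: tuple, is_hopper: bool) -> bool:
    vec_width = 32 // bitwidth
    ok = (bitwidth == 16 or not need_trans) and k_width == vec_width
    if is_hopper:
        # The legacy ldmatrix path on Hopper only handles 32 bits per thread.
        ok = ok and src_bitwidth * k_width == 32
    # Smaller tiles would make ldmatrix read out of bounds (the IMA comment).
    ok = ok and src_shape[0] >= 8 and src_shape[1] >= 4 * k_width
    return ok

# fp16 with kWidth=2: ldmatrix still applies, so Hopper keeps that lowering.
assert can_use_ldmatrix(16, 2, False, 16, (64, 64), is_hopper=True)
# kWidth=1 misses the vecWidth match, so Hopper now takes the LL path instead.
assert not can_use_ldmatrix(16, 1, False, 16, (64, 64), is_hopper=True)
```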

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 8 additions & 50 deletions

```diff
@@ -241,12 +241,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
     llvm::report_fatal_error("Illegal shared layout");
   }
 
-  int vec = 8 * 16 / elemBitWidth;
-  if (vec != shared.getVec()) {
-    llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
-                 << ": " << shared << "\n";
-    llvm::report_fatal_error("Illegal shared layout");
-  }
+  int vec = shared.getVec();
 
   StringAttr colDimName = outDimNames[colDim];
   StringAttr rowDimName = outDimNames[rowDim];
@@ -858,40 +853,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 }
 
 namespace {
-
-// TODO (Keren): Currently, we have more restrictions than necessary when using
-// stmatrix. These restrictions are retained from legacy code, and we could
-// relax some of them in the future.
-bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-                    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
-                    int swizzleByteSize) {
-  auto mmaLayout =
-      mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  if (!mmaLayout || !mmaLayout.isHopper())
-    return false;
-  if (isa<PointerType>(tensorTy.getElementType()))
-    return false;
-  if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
-    return false;
-  if (order[0] != 1)
-    return false;
-
-  auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
-  if (tensorShapePerCTA.size() != 2)
-    return false;
-  auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
-                       ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
-  if (numIterations > 1)
-    return false;
-  if (paddedRepShape[1] % 8 != 0)
-    return false;
-  if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
-      swizzleByteSize != 128)
-    return false;
-  return true;
-}
-
-std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
+LinearLayout chooseStMatrixLayoutLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
     int swizzleByteSize) {
@@ -962,7 +924,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
       .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
+LinearLayout chooseStMatrixLayoutNoLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
   StringAttr kReg = S("register");
@@ -1002,15 +964,11 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
 
 } // anonymous namespace
 
-std::optional<LinearLayout>
-chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                     ArrayRef<unsigned> repShape,
-                     ArrayRef<unsigned> paddedRepShape,
-                     ArrayRef<unsigned> order, int swizzleByteSize) {
-  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
-                      swizzleByteSize))
-    return std::nullopt;
-
+LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                                  ArrayRef<unsigned> repShape,
+                                  ArrayRef<unsigned> paddedRepShape,
+                                  ArrayRef<unsigned> order,
+                                  int swizzleByteSize) {
   if (swizzleByteSize == 0)
     return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
                                                paddedRepShape, order);
```
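
The contract change is worth spelling out: `chooseStMatrixLayout` used to probe applicability itself and return `std::nullopt` on failure; the probe now lives behind the new `TargetInfoBase::canUseStMatrix` hook, and the chooser always returns a layout. A tiny sketch of the before/after calling convention; all three helpers below are placeholders standing in for `targetInfo.canUseStMatrix`, `chooseStMatrixLayout`, and `srcLayout.invertAndCompose(sharedLayout)`, not Triton APIs:

```python
# Placeholder sketch of the new calling convention only.
def can_use_st_matrix(cfg: dict) -> bool:
    return cfg.get("is_hopper_mma", False)

def choose_st_matrix_layout(cfg: dict) -> str:
    return "stmatrix store layout"      # now returns a layout unconditionally

def invert_and_compose(cfg: dict) -> str:
    return "generic shared-memory store layout"

def pick_store_layout(cfg: dict) -> str:
    # Old: choose_st_matrix_layout returned an optional and the caller
    # unwrapped it. New: the caller asks the target first.
    if can_use_st_matrix(cfg):
        return choose_st_matrix_layout(cfg)
    return invert_and_compose(cfg)

assert pick_store_layout({"is_hopper_mma": True}) == "stmatrix store layout"
assert pick_store_layout({}) == "generic shared-memory store layout"
```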

python/test/unit/language/test_core.py

Lines changed: 12 additions & 3 deletions

```diff
@@ -28,6 +28,7 @@
     dtypes_with_bfloat16,
     is_cuda,
     is_interpreter,
+    is_hopper,
     is_hip,
     is_hip_cdna,
     is_hip_mi200,
@@ -195,7 +196,12 @@ def is_layout_applicable(layout) -> bool:
     if layout in common_layouts:
         return True
     elif is_cuda():
-        return isinstance(layout, MmaLayout)
+        mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
+        if not isinstance(mma_layout, MmaLayout):
+            return False
+        if mma_layout.version[0] >= 3 and not is_hopper():
+            return False
+        return True
     elif is_hip():
         target_arch = triton.runtime.driver.active.get_current_target().arch
         if "gfx11" in target_arch:
@@ -5246,6 +5252,9 @@ def kernel(Out):
 # TODO: backend should be tested separately
 
 layouts = [
+    MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
+    DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
+    DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
     BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -5293,9 +5302,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
 
 @pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
 @pytest.mark.parametrize("interm_layout", intermediate_layouts)
-@pytest.mark.parametrize("dst_layout", layouts)
+@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
 def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
     if str(src_layout) == str(dst_layout):
         pytest.skip()
```
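
`filter_layouts` is used by the updated parametrizations but is not defined in this diff. For readers following along, a hypothetical sketch of what such a helper might look like, assuming it builds on the `is_layout_applicable` helper updated above; the real definition in test_core.py may differ:

```python
# Hypothetical sketch only: the actual filter_layouts lives elsewhere in
# test_core.py and may be implemented differently.
import pytest

def filter_layouts(layouts):
    return [
        pytest.param(layout,
                     marks=pytest.mark.skipif(not is_layout_applicable(layout),
                                              reason="layout not supported on this target"))
        for layout in layouts
    ]
```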
