Commit da569f1

Merge commit '3f1d70fb4a0678b3535ad3fbd30d476de9970a81'
2 parents: 3c09bfe + 3f1d70f

30 files changed: 230 additions, 165 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ python/triton/language/extra
 # Proton
 python/triton/profiler

+# Pytest
+pytest.ini
+
 # Instrumentation
 python/triton/instrumentation

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 7 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "triton/Conversion/MLIRTypes.h"

 namespace mlir::triton {
+
 class TargetInfoBase {
 public:
   virtual bool supportMaximumMinimum() const = 0;

@@ -37,6 +38,12 @@ class TargetInfoBase {
                      pred);
   }

+  virtual bool canUseStMatrix(RankedTensorType tensorTy,
+                              ArrayRef<unsigned> repShape,
+                              ArrayRef<unsigned> paddedRepShape,
+                              ArrayRef<unsigned> order,
+                              int swizzleByteSize) const = 0;
+
   virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
                                  Value ptr, Value val) const = 0;

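The new pure-virtual hook turns the stmatrix eligibility test into a per-target query instead of a property inferred inside the layout chooser. Below is a minimal standalone sketch of that pattern with simplified stand-in types (not the real MLIR/Triton signatures); the restrictions in the override loosely mirror the predicate this commit moves out of LinearLayoutConversions.cpp.

// Minimal sketch of the capability-query pattern introduced above.
// Types and the concrete restrictions are simplified stand-ins, not the
// real MLIR/Triton classes or the authoritative eligibility rules.
#include <cstdio>
#include <vector>

struct TargetInfoBaseSketch {
  // Each target reports whether stmatrix can be used for a given store.
  virtual bool canUseStMatrix(unsigned elemBitWidth,
                              const std::vector<unsigned> &order,
                              int swizzleByteSize) const = 0;
  virtual ~TargetInfoBaseSketch() = default;
};

struct HopperTargetInfoSketch : TargetInfoBaseSketch {
  bool canUseStMatrix(unsigned elemBitWidth,
                      const std::vector<unsigned> &order,
                      int swizzleByteSize) const override {
    // 16-bit elements, row-major order, and a supported swizzle size.
    bool swizzleOk = swizzleByteSize == 0 || swizzleByteSize == 32 ||
                     swizzleByteSize == 64 || swizzleByteSize == 128;
    return elemBitWidth == 16 && !order.empty() && order[0] == 1 && swizzleOk;
  }
};

int main() {
  HopperTargetInfoSketch target;
  std::printf("fp16, order {1,0}, swizzle 0 -> %d\n",
              target.canUseStMatrix(16, {1, 0}, 0));
  std::printf("fp32, order {1,0}, swizzle 0 -> %d\n",
              target.canUseStMatrix(32, {1, 0}, 0));
}

Presumably each backend's TargetInfo provides its own override; only the base-class declaration is visible in this excerpt of the commit.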
include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 5 deletions
@@ -241,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // TODO(Keren): We should replace tensorTy with a LinearLayout and the element
 // bit width of the tensor in the future to support more flexible tensor
 // encodings
-std::optional<LinearLayout>
-chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                     ArrayRef<unsigned> repShape,
-                     ArrayRef<unsigned> paddedRepShape,
-                     ArrayRef<unsigned> order, int swizzleByteSize);
+LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                                  ArrayRef<unsigned> repShape,
+                                  ArrayRef<unsigned> paddedRepShape,
+                                  ArrayRef<unsigned> order,
+                                  int swizzleByteSize);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 2 deletions
@@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
     int k = (needTrans) ? matShape[0] : matShape[2];
     int vec = (order[0] == rank-1) ? k : m;
     int mmaStride = (order[0] == rank-1) ? m : k;
-    int maxPhase = mmaStride / perPhase;
+    int maxPhase = std::max(mmaStride / perPhase, 1);
     return get(context, vec, perPhase, maxPhase, order, CTALayout);
   }

@@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
     int k = needTrans ? matShape[1] : matShape[2];
     int vec = (order[0] == rank-1) ? n : k;
     int mmaStride = (order[0] == rank-1) ? k : n;
-    int maxPhase = mmaStride / perPhase;
+    int maxPhase = std::max(mmaStride / perPhase, 1);
     return get(context, vec, perPhase, maxPhase, order, CTALayout);
   }

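The only change in both helpers is clamping maxPhase to at least 1. The clamp matters when mmaStride is smaller than perPhase: integer division then rounds to zero, and a phase count of zero is not a usable swizzling parameter. A tiny standalone check of the arithmetic, with hypothetical values chosen only to show the rounding:

// Illustrative arithmetic for the maxPhase clamp; the concrete values are
// hypothetical, picked so that the division rounds down to zero.
#include <algorithm>
#include <cassert>

int main() {
  int mmaStride = 4, perPhase = 8;               // stride smaller than perPhase
  int before = mmaStride / perPhase;             // integer division -> 0
  int after = std::max(mmaStride / perPhase, 1); // clamped -> 1
  assert(before == 0 && after == 1);
  return 0;
}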
lib/Analysis/Utility.cpp

Lines changed: 2 additions & 1 deletion
@@ -635,7 +635,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
              dotOperandLayout.getOpIdx() == 0 &&
              mmaLayout.getWarpsPerCTA()[1] == 1 &&
              !cvtNeedsSharedMemory(parentTy, srcTy) &&
-             (elementTypeSize == 16 || elementTypeSize == 8);
+             (elementTypeSize == 16 || elementTypeSize == 8) &&
+             dotOperandLayout.getKWidth() == 32 / elementTypeSize;
   return ans;
 }

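The added conjunct ties the dot-operand kWidth to the element width: 32 / elementTypeSize is 2 for 16-bit and 4 for 8-bit elements, so the match only fires for layouts that pack a full 32-bit word per register. A standalone illustration of that relation (only the element sizes the surrounding condition accepts):

// Standalone illustration of the kWidth == 32 / elementTypeSize requirement.
#include <cstdio>

int main() {
  for (int elementTypeSize : {8, 16}) {
    int requiredKWidth = 32 / elementTypeSize; // 4 for 8-bit, 2 for 16-bit
    std::printf("elementTypeSize=%d -> kWidth must be %d\n", elementTypeSize,
                requiredKWidth);
  }
}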
lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 25 additions & 31 deletions
@@ -380,24 +380,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return !useLegacyMMAConversion;
     }
     if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      auto parent = dotOperand.getParent();
-      if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
-        return false;
-      }
-      if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-        if (nvidiaMma.isAmpere()) {
-          return true;
-        }
-      }
-      if (isa<AMDMfmaEncodingAttr>(parent)) {
-        return true;
+      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(
+              dotOperand.getParent())) {
+        return !useLegacyMMAConversion;
       }
       return false;
     }
-    if (isa<BlockedEncodingAttr>(layout)) {
-      return true;
-    }
-    if (isa<LinearEncodingAttr>(layout)) {
+    if (isa<BlockedEncodingAttr, LinearEncodingAttr>(layout)) {
       return true;
     }
     if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {

@@ -408,6 +397,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
       return failure();
     }
+    // FIXME [Dot LL] Remove this once we implement this trick in LLs
+    if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
+      return failure();
+    }

     assert(cvtNeedsSharedMemory(srcTy, dstTy));

@@ -498,34 +491,35 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // don't need to avoid duplicate writes.
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
-    std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
-                             scratchConfig.paddedRepShape, scratchConfig.order,
-                             /*swizzleByteSize=*/0);
-    bool isStMatrix = shmemStoreLayout.has_value();
-    if (!isStMatrix) {
-      shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
-    }
-    assert(shmemStoreLayout.has_value());
+    bool isStMatrix = targetInfo.canUseStMatrix(
+        op.getSrc().getType(), scratchConfig.repShape,
+        scratchConfig.paddedRepShape, scratchConfig.order,
+        /*swizzleByteSize=*/0);
+    LinearLayout shmemStoreLayout =
+        isStMatrix ? chooseStMatrixLayout(
+                         ctx, op.getSrc().getType(), scratchConfig.repShape,
+                         scratchConfig.paddedRepShape, scratchConfig.order,
+                         /*swizzleByteSize=*/0)
+                   : srcLayout.invertAndCompose(sharedLayout);

     const int shmemAllocatedNumElems =
         getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems);
+    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);

     // Layout for the load from shmem to registers.
     LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);

     // Check that the `register` fully determines the `iteration`. That is,
     // each thread does exactly the same reads and writes to shmem on each
     // iteration, just with different input/output registers.
-    assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock},
-                                             {kIteration}));
+    assert(
+        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
     assert(
         shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));

     // iteration -> registers
     SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, *shmemStoreLayout);
+        collectRegsForIter(ctx, shmemStoreLayout);
     SmallVector<SmallVector<int>> outRegsForIter =
         collectRegsForIter(ctx, shmemLoadLayout);

@@ -582,7 +576,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return vecAddr;
     };

-    auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout,
+    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
                                        {{kRegister, i32_val(0)},
                                         {kLane, laneId},
                                         {kWarp, warpId},

@@ -605,11 +599,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

     // When using `stmatrix`, we can store `inVec` elements even if they are
     // not contiguous
-    auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut()
+    auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
                             : scratchConfig.inVec;
     for (int j = 0; j < inVals.size() / iterations; j += inVec) {
       auto inRegSlice = inRegs[j];
-      Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice);
+      Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
       SmallVector<Value> inValsVec;
       for (int k = 0; k < inVec; k++)
         inValsVec.push_back(inVals[inRegSlice + k]);

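The net effect of the middle hunk is an API cleanup: instead of a chooser that returns std::optional<LinearLayout> (empty meaning "fall back to the generic shared layout"), the code first asks the target whether stmatrix is usable and then selects one of two plain layouts, removing the optional dereferences in the rest of the function. A schematic sketch of the before/after shape of that refactor, with a placeholder Layout type standing in for LinearLayout and made-up helper names:

// Schematic of the refactor applied above: replace an optional-returning
// chooser plus fallback with an explicit capability query and a ternary.
// "Layout" and the helper names are placeholders, not the real Triton API.
#include <iostream>
#include <optional>
#include <string>

using Layout = std::string;

// Before: an empty optional meant "stmatrix not applicable".
std::optional<Layout> chooseStMatrixLayoutOld(bool eligible) {
  return eligible ? std::optional<Layout>("stmatrix layout") : std::nullopt;
}

// After: eligibility is queried separately and the chooser always succeeds.
bool canUseStMatrix(bool eligible) { return eligible; }
Layout chooseStMatrixLayoutNew() { return "stmatrix layout"; }
Layout genericStoreLayout() { return "invertAndCompose(shared) layout"; }

int main() {
  // Old shape of the code: check has_value() and fall back.
  std::optional<Layout> maybe = chooseStMatrixLayoutOld(false);
  Layout oldStyle = maybe.has_value() ? *maybe : genericStoreLayout();

  // New shape of the code: query first, then a plain conditional.
  bool isStMatrix = canUseStMatrix(false);
  Layout newStyle =
      isStMatrix ? chooseStMatrixLayoutNew() : genericStoreLayout();

  std::cout << oldStyle << " == " << newStyle << "\n";
}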
lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 29 additions & 15 deletions
@@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(RankedTensorType type) {
-    auto layout = type.getEncoding();
-    auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+  static bool isSupportedDotOpLayout(MemDescType srcTy,
+                                     RankedTensorType dstTy) {
+    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
+    auto dstLayout = dstTy.getEncoding();
+    auto bitwidth = dstTy.getElementTypeBitWidth();
+    auto rank = dstTy.getRank();
+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
+      auto vecWidth = 32 / bitwidth;
       auto kWidth = dot.getKWidth();
-      // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
-      // - kWidth == 8
-      // - kWidth == 4, bitwidth = 32
+      auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+        auto needTrans = kOrder != srcLayout.getOrder()[0];
+        auto canUseLdmatrix =
+            (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
+        if (mma.isHopper()) {
+          // I think we should be able to remove this condition, but it's here
+          // as the legacy ldmatrix path does not support it
+          canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
+        }
+        // If we remove this one, ldmatrix will IMA. It can probably be relaxed
+        // though
+        canUseLdmatrix &=
+            srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
+        // To be removed in https://github.com/triton-lang/triton/pull/5154
         bool legacyLoweringIsBuggy =
-            kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
-        return legacyLoweringIsBuggy && mma.isAmpere();
+            (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+        return (mma.isHopper() && !canUseLdmatrix) ||
+               (mma.isAmpere() && legacyLoweringIsBuggy);
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
         return true;

@@ -162,12 +178,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
-    Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isa<SharedEncodingAttr>(srcLayout) &&
-        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-             LinearEncodingAttr>(dstLayout) ||
-         isSupportedDotOpLayout(dstTy))) {
+    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
+            LinearEncodingAttr>(dstLayout) ||
+        isSupportedDotOpLayout(srcTy, dstTy)) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }

@@ -206,7 +220,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");

     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(

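The rewritten predicate decides between the linear-layout path and the legacy shared-to-dot-operand lowering, and the interesting part is the ldmatrix eligibility test. Below is a rough standalone restatement of those conditions with plain integers standing in for the MLIR types; the thresholds are the ones visible in the diff, while the function name and the concrete values in main are made up for illustration:

// Standalone restatement of the ldmatrix eligibility conditions added above.
// Plain values stand in for the MLIR types; thresholds follow the diff.
#include <cstdio>

bool canUseLdmatrixSketch(int bitwidth, int kWidth, bool needTrans,
                          bool isHopper, long dim0, long dim1) {
  int vecWidth = 32 / bitwidth;
  bool ok = (bitwidth == 16 || !needTrans) && (kWidth == vecWidth);
  if (isHopper)
    ok &= bitwidth * kWidth == 32;         // legacy Hopper path restriction
  ok &= dim0 >= 8 && dim1 >= 4 * kWidth;   // avoid out-of-bounds ldmatrix
  return ok;
}

int main() {
  // fp16 operand, kWidth 2, no transpose, 64x64 shared tile on Hopper.
  std::printf("%d\n", canUseLdmatrixSketch(16, 2, false, true, 64, 64));
  // 8-bit operand with kWidth 4 but a tile too narrow for ldmatrix.
  std::printf("%d\n", canUseLdmatrixSketch(8, 4, false, true, 64, 8));
}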
lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 5 additions & 0 deletions
@@ -199,6 +199,11 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }

+  // Eliminate splat constant transpose ops.
+  if (auto attr =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSrc()))
+    return attr.reshape(getType());
+
   return {};
 }

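The fold is valid because a splat constant is unchanged by any permutation of its elements: transposing it only permutes the shape, so the splat attribute can simply be reshaped to the result type. A toy standalone check of that property, unrelated to the MLIR types involved:

// Toy check that transposing a splat-filled matrix yields the same splat,
// only with the dimensions swapped (illustrative values).
#include <cassert>
#include <vector>

int main() {
  const float c = 3.5f;
  const int rows = 2, cols = 4;
  std::vector<std::vector<float>> m(rows, std::vector<float>(cols, c));
  // Transpose into a cols x rows matrix.
  std::vector<std::vector<float>> t(cols, std::vector<float>(rows));
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      t[j][i] = m[i][j];
  // Every element of the transpose is still the splat value.
  for (const auto &row : t)
    for (float v : row)
      assert(v == c);
  return 0;
}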
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 8 additions & 50 deletions
@@ -242,12 +242,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
     llvm::report_fatal_error("Illegal shared layout");
   }

-  int vec = 8 * 16 / elemBitWidth;
-  if (vec != shared.getVec()) {
-    llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec
-                 << ": " << shared << "\n";
-    llvm::report_fatal_error("Illegal shared layout");
-  }
+  int vec = shared.getVec();

   StringAttr colDimName = outDimNames[colDim];
   StringAttr rowDimName = outDimNames[rowDim];

@@ -862,40 +857,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 }

 namespace {
-
-// TODO (Keren): Currently, we have more restrictions than necessary when using
-// stmatrix. These restrictions are retained from legacy code, and we could
-// relax some of them in the future.
-bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-                    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
-                    int swizzleByteSize) {
-  auto mmaLayout =
-      mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  if (!mmaLayout || !mmaLayout.isHopper())
-    return false;
-  if (isa<PointerType>(tensorTy.getElementType()))
-    return false;
-  if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
-    return false;
-  if (order[0] != 1)
-    return false;
-
-  auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
-  if (tensorShapePerCTA.size() != 2)
-    return false;
-  auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
-                       ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
-  if (numIterations > 1)
-    return false;
-  if (paddedRepShape[1] % 8 != 0)
-    return false;
-  if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
-      swizzleByteSize != 128)
-    return false;
-  return true;
-}
-
-std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
+LinearLayout chooseStMatrixLayoutLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
     int swizzleByteSize) {

@@ -966,7 +928,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
       .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }

-std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
+LinearLayout chooseStMatrixLayoutNoLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
   StringAttr kReg = S("register");

@@ -1006,15 +968,11 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(

 } // anonymous namespace

-std::optional<LinearLayout>
-chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                     ArrayRef<unsigned> repShape,
-                     ArrayRef<unsigned> paddedRepShape,
-                     ArrayRef<unsigned> order, int swizzleByteSize) {
-  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
-                      swizzleByteSize))
-    return std::nullopt;
-
+LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                                  ArrayRef<unsigned> repShape,
+                                  ArrayRef<unsigned> paddedRepShape,
+                                  ArrayRef<unsigned> order,
+                                  int swizzleByteSize) {
   if (swizzleByteSize == 0)
     return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
                                                paddedRepShape, order);

python/test/unit/language/test_core.py

Lines changed: 9 additions & 3 deletions
@@ -28,6 +28,7 @@
     dtypes_with_bfloat16,
     is_cuda,
     is_interpreter,
+    is_hopper,
     is_hip,
     is_hip_cdna,
     is_hip_mi200,

@@ -220,7 +221,12 @@ def is_layout_applicable(layout) -> bool:
     if layout in common_layouts:
         return True
     elif is_cuda():
-        return isinstance(layout, MmaLayout)
+        mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
+        if not isinstance(mma_layout, MmaLayout):
+            return False
+        if mma_layout.version[0] >= 3 and not is_hopper():
+            return False
+        return True
     elif is_hip():
         target_arch = triton.runtime.driver.active.get_current_target().arch
         if "gfx11" in target_arch:

@@ -5342,9 +5348,9 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):

 @pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
 @pytest.mark.parametrize("interm_layout", intermediate_layouts)
-@pytest.mark.parametrize("dst_layout", layouts)
+@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
 def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
     if str(src_layout) == str(dst_layout):
         pytest.xfail("Do not convert same layout")

0 commit comments
