
Commit 1af120b

Merge commit '0560390b3b04286515b34c060188b9d77cb5e1b1'
2 parents: cd4527d + 0560390

File tree

35 files changed: +375 -426 lines changed


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@
 #include <unordered_map>

 // LinearLayoutCache Utils
-using CacheKey = std::tuple<std::vector<int64_t>, mlir::Attribute>;
+using CacheKey =
+    std::tuple<std::vector<int64_t>, mlir::Attribute, std::vector<int64_t>>;

 namespace llvm {
 template <typename T> size_t hash_value(const std::vector<T> &vec) {

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ class MemDescType;
 // elemBitWidth is the bit width of one element in the layout. This is required
 // to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
 // shared layouts with nvmma_shared layout) but is otherwise unused.
-LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
+LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
+                            ArrayRef<int64_t> allocationShape);
 LinearLayout toLinearLayout(RankedTensorType type);
 LinearLayout toLinearLayout(MemDescType type);
 LinearLayout toLinearLayout(TensorOrMemDesc type);

include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def TritonGPU_Dialect : Dialect {
   let extraClassDeclaration = [{
     void registerTypes();

-    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
+    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout, ArrayRef<int64_t> allocationShape);
     LinearEncodingAttr toLinearEncoding(ArrayRef<int64_t> shape, Attribute layout);

     static int getNumCTAs(ModuleOp mod);

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 21 additions & 0 deletions
@@ -273,6 +273,27 @@ def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
   }];

   let arguments = (ins TTG_MemDescType:$src);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "ArrayRef<int64_t>":$shape),
+    [{
+      MemDescType dstTy;
+      auto srcTy = cast<MemDescType>(src.getType());
+      auto result = inferReturnTypes($_builder.getContext(),
+                                     $_builder.getUnknownLoc(),
+                                     srcTy, shape, dstTy);
+      assert(succeeded(result) && "failed to infer return types");
+      build($_builder, $_state, dstTy, src);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static LogicalResult inferReturnTypes(MLIRContext *context,
+                                          std::optional<Location> loc,
+                                          MemDescType srcTy,
+                                          ArrayRef<int64_t> dstShape,
+                                          MemDescType &inferredReturnType);
+  }];
+
   let results = (outs TTG_MemDescType:$result);

   let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 3 deletions
@@ -325,7 +325,7 @@ class LinearLayout {
       bases;

   llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
-  bool surjective = true;
+  int32_t rank = 0;

 public:
   using BasesT = decltype(bases);
@@ -425,10 +425,11 @@ class LinearLayout {
       ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
       ArrayRef<std::pair<StringAttr, int32_t>> outDims, bool requireSurjective);

-  bool isSurjective() const { return surjective; }
+  bool isSurjective() const { return rank == getTotalOutDimSizeLog2(); }
+  bool isInjective() const { return rank == getTotalInDimSizeLog2(); }

   bool isInvertible() const {
-    return surjective && getTotalInDimSize() == getTotalOutDimSize();
+    return isSurjective() && getTotalInDimSize() == getTotalOutDimSize();
   }

   const BasesT &getBases() const { return bases; }
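
The LinearLayout.h change replaces the cached surjective flag with the stored GF(2) rank of the basis matrix: the layout is surjective exactly when that rank equals the log2 of the total output size, and injective exactly when it equals the log2 of the total input size. A minimal standalone sketch of that rank-based check, using plain bitmask Gaussian elimination rather than the Triton data structures (illustrative only):

// Sketch: rank over GF(2) decides surjectivity/injectivity of a linear layout.
// Standalone illustration; not the Triton implementation.
#include <cstdint>
#include <iostream>
#include <vector>

// Rank of a set of basis vectors, each encoded as a bitmask of output bits,
// computed by Gaussian elimination with XOR.
static int rankGF2(std::vector<uint32_t> basis) {
  int rank = 0;
  for (size_t i = 0; i < basis.size(); ++i) {
    if (basis[i] == 0)
      continue; // a zero basis vector (broadcasting) adds nothing to the rank
    uint32_t pivot = basis[i] & (~basis[i] + 1u); // lowest set bit
    ++rank;
    for (size_t j = i + 1; j < basis.size(); ++j)
      if (basis[j] & pivot)
        basis[j] ^= basis[i]; // clear the pivot bit in the remaining rows
  }
  return rank;
}

int main() {
  // 4 input bits (16 registers) mapping onto 3 output bits (8 elements),
  // with one zero basis vector, i.e. a broadcasted layout.
  std::vector<uint32_t> bases = {1, 2, 0, 4};
  int totalInDimSizeLog2 = static_cast<int>(bases.size()); // one vector per input bit
  int totalOutDimSizeLog2 = 3;                             // 8 output elements
  int rank = rankGF2(bases);
  std::cout << std::boolalpha;
  std::cout << "surjective: " << (rank == totalOutDimSizeLog2) << "\n"; // true
  std::cout << "injective:  " << (rank == totalInDimSizeLog2) << "\n";  // false
}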

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 18 additions & 14 deletions
@@ -40,11 +40,13 @@ namespace {
 LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
                                   LinearLayout regLayout,
                                   triton::gpu::SharedEncodingTrait dstEnc,
-                                  int elemBitWidth) {
+                                  int elemBitWidth,
+                                  ArrayRef<int64_t> allocShape) {
   StringAttr kBlock = StringAttr::get(ctx, ("block"));
   int rank = shape.size();

-  LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc);
+  LinearLayout sharedLayout =
+      triton::gpu::toLinearLayout(shape, dstEnc, allocShape);
   auto sharedOrder = triton::gpu::getOrder(dstEnc, shape);

   // sharedLayout's in-dims are currently (offset, block). Reshape to
@@ -399,7 +401,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   MLIRContext *ctx = rewriter.getContext();
   auto shape = type.getShape();

-  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
+  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout, {});

   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
@@ -524,7 +526,7 @@ SmallVector<Value> getSmemVecAddrVec(
           sharedEnc)) {
     auto regToSharedSwizzledLayout =
         getRegToSharedLayout(ctx, shape, regLayout, swizzledSharedEnc,
-                             elemLlvmTy.getIntOrFloatBitWidth());
+                             elemLlvmTy.getIntOrFloatBitWidth(), allocShape);
     auto smemOrder = swizzledSharedEnc.getOrder();

     auto swizzledIndicesVec =
@@ -680,9 +682,9 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
   bool isStore = !valsArray.empty();
   auto b = TritonLLVMOpBuilder(loc, rewriter);

-  auto emitCpAsync = [&](ConversionPatternRewriter &rewriter, Location loc,
-                         ArrayRef<Value> vals, Value shmemAddr, int idx,
-                         VectorType vecTy) -> SmallVector<Value> {
+  auto emitLdSt = [&](ConversionPatternRewriter &rewriter, Location loc,
+                      ArrayRef<Value> vals, Value shmemAddr, int idx,
+                      VectorType vecTy) -> SmallVector<Value> {
     auto length = vecTy.getNumElements();
     if (isStore) {
       Value valsVec =
@@ -698,7 +700,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
     }
   };
   return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, rewriter,
-                   targetInfo, {}, emitCpAsync);
+                   targetInfo, {}, emitLdSt);
 }

 SmallVector<Value> lowerLdSt(
@@ -880,11 +882,13 @@ bool emitTransferBetweenRegistersAndShared(
   auto allocShape = sharedTy.getAllocShape();
   auto invertAllocSharedLayout = LinearLayout::empty();
   if (!paddedLayout) {
-    // For now this is only needed for the cases where we have swizzling.
-    invertAllocSharedLayout =
-        triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
-                                    sharedTy.getEncoding())
-            .pseudoinvert();
+    // This is the legacy way of doing things that's much more ad-hoc
+    // For generic shared layouts it may or may not be correct
+    auto allocShape = sharedTy.getAllocShape();
+    auto trimShape = allocShape.take_back(sharedTy.getRank());
+    invertAllocSharedLayout = triton::gpu::toLinearLayout(
+                                  trimShape, sharedTy.getEncoding(), trimShape)
+                                  .pseudoinvert();
   }

   int numElems = regToSharedLayout.getInDimSize(kRegister);
@@ -1494,7 +1498,7 @@ delinearize(RewriterBase &rewriter, Location loc,
             triton::gpu::DistributedEncodingTrait layout,
             ArrayRef<int64_t> shape, StringAttr dimName, Value linear) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
-  auto ll = triton::gpu::toLinearLayout(shape, layout);
+  auto ll = triton::gpu::toLinearLayout(shape, layout, {});
   auto linearLayout =
       triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll);
   assert(ll.hasInDim(dimName));

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 11 additions & 17 deletions
@@ -471,35 +471,24 @@ struct MemDescSubviewOpConversion
     // newBase = base + offset
     auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                    llvmElemTy, rewriter);
-    auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
     SmallVector<Value> opOffsetVals = op.getOffsets();
     // We assume we always create a subview of the last dimensions
-    SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
-                                     smemStrides.end());
     // Compute total offset
-    SmallVector<Value> offsetVals;
-    auto destRank = op.getResult().getType().getRank();
-    auto rankReduced = srcTy.getRank() - destRank;
-    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
-      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
-    }
+    auto rankReduced = srcTy.getRank() - destTy.getRank();

     Value offset;
     if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) {
+      auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
+      SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
+                                       smemStrides.end());
       // We are splitting the pipelining dimension which may not be a power of 2
       // so we can't use LinearLayouts
       offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
     } else {
       auto dimNames = standardOutDimNames(ctx, opOffsetVals.size());
       SmallVector<std::pair<StringAttr, Value>> logicalOffsets;
-      // This assumes the subviews are additive, in the sense that we can
-      // compute the offset of one and an add it to the offset of the previous
-      // one we computed. We check for this in the verifier.
-      for (int i = 0; i < rankReduced; i++) {
-        logicalOffsets.push_back({dimNames[i], b.i32_val(0)});
-      }
-      for (int i = rankReduced; i < opOffsetVals.size(); i++) {
-        logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]});
+      for (auto [dim, offset] : llvm::zip(dimNames, opOffsetVals)) {
+        logicalOffsets.push_back({dim, offset});
       }
       auto ll = toLinearLayout(srcTy);
       // Checked in the verifier.
@@ -517,6 +506,11 @@ struct MemDescSubviewOpConversion
       offset = b.add(offset, padOffset);
     }

+    SmallVector<Value> offsetVals;
+    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
+      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
+    }
+
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
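
In the rewritten MemDescSubviewOp lowering above, the rank-reducing (pipelining) branch keeps computing the flat shared-memory offset as a dot product of the subview offsets with the memdesc strides, since that dimension need not be a power of two. A tiny standalone sketch of that arithmetic with hypothetical constant offsets and strides (the real lowering emits LLVM IR values instead):

// Sketch: flat offset = dot(offsets, strides), as in the non-power-of-2 branch.
// Standalone illustration with constants; not the Triton code.
#include <cstdint>
#include <iostream>
#include <vector>

static int64_t dotOffset(const std::vector<int64_t> &offsets,
                         const std::vector<int64_t> &strides) {
  int64_t acc = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    acc += offsets[i] * strides[i]; // accumulate per-dimension contributions
  return acc;
}

int main() {
  // Hypothetical 3x64x64 row-major buffer: strides {4096, 64, 1}.
  // Taking the subview at stage 1 of the pipelining dimension:
  std::cout << dotOffset({1, 0, 0}, {64 * 64, 64, 1}) << "\n"; // prints 4096
}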

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 52 additions & 11 deletions
@@ -42,11 +42,14 @@ namespace gpu {

 LinearEncodingAttr TritonGPUDialect::toLinearEncoding(ArrayRef<int64_t> shape,
                                                       Attribute layout) {
-  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
+  // LinearEncoding is a DistributedLayout
+  std::vector<int64_t> allocationShape;
+  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout,
+               allocationShape};
   if (auto result = leCache.get(key)) {
     return *result;
   }
-  auto linearLayout = toLinearLayout(shape, layout);
+  auto linearLayout = toLinearLayout(shape, layout, {});
   auto linearEncoding =
       LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout));
   leCache.set(key, linearEncoding);
@@ -2386,7 +2389,7 @@ struct TritonGPUInferLayoutInterface
       return success();
     }

-    auto ll = toLinearLayout(shape, operandEncoding);
+    auto ll = toLinearLayout(shape, operandEncoding, {});
     auto transposedLl = transposeLinearLayout(ll, order);
     resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
     return success();
@@ -2483,6 +2486,39 @@ struct TritonGPUInferLayoutInterface
                                       Attribute srcEnc,
                                       ArrayRef<int64_t> dstShape,
                                       Attribute &dstEnc) const {
+    if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
+      // TODO: supporting reshape of CTA layouts is non-trivial.
+      if (getNumCTAs(mmaEncoding) > 1)
+        return failure();
+      int innerDimDst =
+          mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
+      int innerDimSrc =
+          mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
+      // For now disallow reshape of the inner dimension.
+      if (innerDimDst != innerDimSrc)
+        return failure();
+      auto *ctx = srcEnc.getContext();
+
+      // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
+      auto CTALayout = CTALayoutAttr::get(
+          ctx,
+          /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
+          /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
+          /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
+      dstEnc = NVMMASharedEncodingAttr::get(
+          ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
+          mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
+          CTALayout);
+      // Big guns, check linear layouts are equivalent
+      // We disallow reshaping memdesc_subviews in the verifier
+      // We disallow reshaping memdesc_subviews in the verifier
+      auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
+      auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
+      if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
+        return failure();
+      }
+      return success();
+    }
     auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
     if (!src) {
       return failure();
@@ -2730,6 +2766,10 @@ struct TritonGPUInferLayoutInterface
     if (succeeded(result)) {
       return result;
     }
+    if (!isa<DistributedEncodingTrait>(srcEnc)) {
+      return emitOptionalError(loc,
+                               "Failed MemDescReshapeOp encoding inference");
+    }
     // If the legacy encoding failed use LinearLayouts.
     // Once LinearLayouts are more widely used, we can remove
     // inferReshapeOpLegacyEncoding and simply use LLs.
@@ -2755,7 +2795,7 @@ struct TritonGPUInferLayoutInterface
     SmallVector<int64_t> joinedShape(shape);
     joinedShape.push_back(2);
     auto parent = enc.getParent();
-    auto parentLL = toLinearLayout(joinedShape, parent);
+    auto parentLL = toLinearLayout(joinedShape, parent, {});

     Attribute splitEnc;
     auto result = inferSplitOpEncoding(parent, splitEnc, joinedShape, loc);
@@ -2791,7 +2831,7 @@ struct TritonGPUInferLayoutInterface
     }

     // Append dim to shape
-    auto ll = toLinearLayout(shape, srcEnc);
+    auto ll = toLinearLayout(shape, srcEnc, {});
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
     dstShape.push_back(1);
     ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));
@@ -2847,7 +2887,7 @@ struct TritonGPUInferLayoutInterface
     auto ctx = getContext();

     // Split on last dim
-    auto ll = toLinearLayout(shape, srcEnc);
+    auto ll = toLinearLayout(shape, srcEnc, {});
     auto newLl = LinearLayout::empty();
     auto result =
         tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc);
@@ -2916,7 +2956,7 @@ struct TritonGPUInferLayoutInterface
       }
     }

-    auto ll = toLinearLayout(shape, inEnc);
+    auto ll = toLinearLayout(shape, inEnc, {});
     auto newLl = LinearLayout::empty();
     auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
     if (!result.succeeded())
@@ -3027,15 +3067,16 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
     return "";

   // This RankedTensorType is a MemDescType (?!)
-  LinearLayout ll = triton::gpu::toLinearLayout(type);
+  auto shape = type.getShape();
+  auto layout = type.getEncoding();
+  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout, shape);

   StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
   StringAttr kBlock = StringAttr::get(type.getContext(), "block");
   int64_t tensorSize = product(type.getShape());
   auto enc = type.getEncoding();
   unsigned numBlocks = getNumCTAs(enc);
   int32_t blockSize = tensorSize / numBlocks;
-  auto shape = type.getShape();

   // elementMapping is for the non-hw layout, offsetMapping for hw-layout
   std::vector<std::string> elementMapping(tensorSize);
@@ -3448,8 +3489,8 @@ int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) {
 bool triton::gpu::areLayoutsEquivalent(ArrayRef<int64_t> shape,
                                        DistributedEncodingTrait lhs,
                                        DistributedEncodingTrait rhs) {
-  auto lhsLL = triton::gpu::toLinearLayout(shape, lhs);
-  auto rhsLL = triton::gpu::toLinearLayout(shape, rhs);
+  auto lhsLL = triton::gpu::toLinearLayout(shape, lhs, {});
+  auto rhsLL = triton::gpu::toLinearLayout(shape, rhs, {});
   return lhsLL == rhsLL;
 }
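
The new NVMMASharedEncodingAttr branch of the reshape-encoding inference above bails out on multi-CTA layouts, refuses to change the inner (swizzled) dimension, and then double-checks the result by comparing linear layouts. A standalone sketch of the two shape-level preconditions, folding in the equal-element-count invariant that the reshape op itself guarantees elsewhere (illustrative only, not the Triton code):

// Sketch: shape preconditions for reshaping an NVMMA shared encoding.
// Standalone illustration; the real inference also rebuilds the encoding
// and compares linear layouts via reshapeLayout.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

static bool reshapePreconditionsHold(const std::vector<int64_t> &srcShape,
                                     const std::vector<int64_t> &dstShape,
                                     bool transposed) {
  auto product = [](const std::vector<int64_t> &s) {
    return std::accumulate(s.begin(), s.end(), int64_t{1},
                           std::multiplies<int64_t>());
  };
  // A reshape must preserve the number of elements (checked by the op itself).
  if (product(srcShape) != product(dstShape))
    return false;
  // The inner (contiguous, swizzled) dimension must stay the same.
  int64_t innerSrc = transposed ? srcShape.front() : srcShape.back();
  int64_t innerDst = transposed ? dstShape.front() : dstShape.back();
  return innerSrc == innerDst;
}

int main() {
  std::cout << std::boolalpha;
  // Folding the two leading dims is a candidate: 2x64x64 -> 128x64.
  std::cout << reshapePreconditionsHold({2, 64, 64}, {128, 64}, false) << "\n";
  // Changing the inner dim is rejected: 128x64 -> 64x128.
  std::cout << reshapePreconditionsHold({128, 64}, {64, 128}, false) << "\n";
}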
