Commit 55cea61

[LAYOUTS] Fix memdesc_subviews when we don't slice along the swizzling pattern (#7480)
The previous way of handling `memdesc_subviews` in the context of LinearLayouts was wrong. Consider a 1D shared-memory layout over `64` elements of the form

```
A = {offset = [1, 2, 5, 10, 32, 16], {dim0=64}}
```

If we take a `subview` of `A` that yields a `tensor<32xtype>`, we should get the layout

```
A_sub = {offset = [1, 2, 5, 10, 0, 16], {dim0=32}}
```

which maps the full shared memory onto a tensor with `32` elements. When we compute `A_sub^{-1}B` for `B` a distributed layout on a tensor of `32` elements, loading through this layout gives the correct mapping on the offsets, as expected.

This PR fixes the issue by passing the initial shape of the shared memory and resizing the layout when the `LinearLayout` is created. This makes explicit what we had already realised at the IR level: subviews depend not only on the shape of the layout, but also on the initial shape of the shared memory. The PR also removes a number of hacks that worked around the issue above.

To do this, we generalise `lstsq` to compute a left inverse of `A` when `A` is injective but its image is not a subset of that of `B`.

The case where we split along a dimension that is within the swizzling pattern will be fixed in a follow-up PR.
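As a minimal, standalone sketch (illustrative code, not Triton's `LinearLayout` API) of the 1-D resizing step described above: every offset basis that lands outside the new logical size is mapped to `0` rather than dropped, so the offset in-dimension keeps the size of the full allocation while the out-dimension shrinks to the view.

```cpp
// Restrict a 1-D shared-memory linear layout, given as XOR bases over the
// offset bits, to a smaller power-of-two view. Bases outside the view map
// to 0 instead of being removed, so the number of offset bits is unchanged.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> subviewBases(const std::vector<int32_t> &bases,
                                  int32_t newSize) {
  std::vector<int32_t> result;
  for (int32_t b : bases)
    result.push_back(b < newSize ? b : 0); // outside the view -> map to 0
  return result;
}

int main() {
  // A = {offset = [1, 2, 5, 10, 32, 16]} over dim0 = 64.
  std::vector<int32_t> A = {1, 2, 5, 10, 32, 16};
  // Subview yielding a 32-element tensor: expect [1, 2, 5, 10, 0, 16].
  for (int32_t b : subviewBases(A, 32))
    std::cout << b << " ";
  std::cout << "\n";
}
```

Running it prints `1 2 5 10 0 16`, the `A_sub` bases from the example.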
1 parent 322cd5b commit 55cea61

File tree: 17 files changed, +183 −114 lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 1 deletion

@@ -15,7 +15,8 @@
 #include <unordered_map>
 
 // LinearLayoutCache Utils
-using CacheKey = std::tuple<std::vector<int64_t>, mlir::Attribute>;
+using CacheKey =
+    std::tuple<std::vector<int64_t>, mlir::Attribute, std::vector<int64_t>>;
 
 namespace llvm {
 template <typename T> size_t hash_value(const std::vector<T> &vec) {
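A hedged illustration (standalone code, not Triton's cache) of why the allocation shape now has to be part of the cache key above: two layouts with the same logical shape and encoding but different allocation shapes lower to different LinearLayouts, so caching on `{shape, attribute}` alone could return a stale entry. The string stands in for `mlir::Attribute`.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Stand-in for CacheKey: {logical shape, encoding, allocation shape}.
using Key = std::tuple<std::vector<int64_t>, std::string, std::vector<int64_t>>;

int main() {
  std::map<Key, int> cache;
  std::vector<int64_t> shape{32};
  // Same logical shape, same encoding, different allocation shapes.
  cache[std::make_tuple(shape, std::string("swizzled"),
                        std::vector<int64_t>{32})] = 1; // built for a 32-elem alloc
  cache[std::make_tuple(shape, std::string("swizzled"),
                        std::vector<int64_t>{64})] = 2; // subview of a 64-elem alloc
  return cache.size() == 2 ? 0 : 1; // two distinct entries, no collision
}
```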

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 1 deletion

@@ -47,7 +47,8 @@ class MemDescType;
 // elemBitWidth is the bit width of one element in the layout. This is required
 // to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
 // shared layouts with nvmma_shared layout) but is otherwise unused.
-LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
+LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
+                            ArrayRef<int64_t> allocationShape);
 LinearLayout toLinearLayout(RankedTensorType type);
 LinearLayout toLinearLayout(MemDescType type);
 LinearLayout toLinearLayout(TensorOrMemDesc type);

include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def TritonGPU_Dialect : Dialect {
   let extraClassDeclaration = [{
     void registerTypes();
 
-    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
+    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout, ArrayRef<int64_t> allocationShape);
     LinearEncodingAttr toLinearEncoding(ArrayRef<int64_t> shape, Attribute layout);
 
     static int getNumCTAs(ModuleOp mod);

include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 3 deletions

@@ -325,7 +325,7 @@ class LinearLayout {
       bases;
 
   llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
-  bool surjective = true;
+  int32_t rank = 0;
 
 public:
   using BasesT = decltype(bases);
@@ -425,10 +425,11 @@
       ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
       ArrayRef<std::pair<StringAttr, int32_t>> outDims, bool requireSurjective);
 
-  bool isSurjective() const { return surjective; }
+  bool isSurjective() const { return rank == getTotalOutDimSizeLog2(); }
+  bool isInjective() const { return rank == getTotalInDimSizeLog2(); }
 
   bool isInvertible() const {
-    return surjective && getTotalInDimSize() == getTotalOutDimSize();
+    return isSurjective() && getTotalInDimSize() == getTotalOutDimSize();
   }
 
   const BasesT &getBases() const { return bases; }
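Reading the diff above: instead of storing a `surjective` flag, the layout now stores the GF(2) rank of its basis vectors, and both surjectivity and injectivity fall out of comparing that rank with the log2 sizes of the out- and in-dimensions. A hedged, standalone sketch (not the Triton implementation) of computing such a rank by Gaussian elimination over bit-vectors:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Each row is one basis vector, packed as a bitmask of output bits.
// A layout with 2^m inputs and 2^n outputs is surjective iff rank == n
// and injective iff rank == m.
int rankGF2(std::vector<uint64_t> rows) {
  int rank = 0;
  for (int bit = 63; bit >= 0; --bit) {
    // Find a not-yet-used row with a 1 in this column.
    size_t pivot = rank;
    while (pivot < rows.size() && !((rows[pivot] >> bit) & 1))
      ++pivot;
    if (pivot == rows.size())
      continue;
    std::swap(rows[rank], rows[pivot]);
    // Eliminate this bit from every other row.
    for (size_t r = 0; r < rows.size(); ++r)
      if (r != static_cast<size_t>(rank) && ((rows[r] >> bit) & 1))
        rows[r] ^= rows[rank];
    ++rank;
  }
  return rank;
}
```

For the `A_sub` example from the commit message, the bases `{1, 2, 5, 10, 0, 16}` have rank 5, so the subview layout is surjective onto the 32-element view (5 == log2 32) but no longer injective from the 64-offset in-dimension (5 < 6).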

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 18 additions & 14 deletions

@@ -40,11 +40,13 @@ namespace {
 LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
                                   LinearLayout regLayout,
                                   triton::gpu::SharedEncodingTrait dstEnc,
-                                  int elemBitWidth) {
+                                  int elemBitWidth,
+                                  ArrayRef<int64_t> allocShape) {
   StringAttr kBlock = StringAttr::get(ctx, ("block"));
   int rank = shape.size();
 
-  LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc);
+  LinearLayout sharedLayout =
+      triton::gpu::toLinearLayout(shape, dstEnc, allocShape);
   auto sharedOrder = triton::gpu::getOrder(dstEnc, shape);
 
   // sharedLayout's in-dims are currently (offset, block). Reshape to
@@ -378,7 +380,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   MLIRContext *ctx = rewriter.getContext();
   auto shape = type.getShape();
 
-  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
+  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout, {});
 
   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
@@ -503,7 +505,7 @@ SmallVector<Value> getSmemVecAddrVec(
           sharedEnc)) {
     auto regToSharedSwizzledLayout =
         getRegToSharedLayout(ctx, shape, regLayout, swizzledSharedEnc,
-                             elemLlvmTy.getIntOrFloatBitWidth());
+                             elemLlvmTy.getIntOrFloatBitWidth(), allocShape);
     auto smemOrder = swizzledSharedEnc.getOrder();
 
     auto swizzledIndicesVec =
@@ -659,9 +661,9 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
   bool isStore = !valsArray.empty();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
 
-  auto emitCpAsync = [&](ConversionPatternRewriter &rewriter, Location loc,
-                         ArrayRef<Value> vals, Value shmemAddr, int idx,
-                         VectorType vecTy) -> SmallVector<Value> {
+  auto emitLdSt = [&](ConversionPatternRewriter &rewriter, Location loc,
+                      ArrayRef<Value> vals, Value shmemAddr, int idx,
+                      VectorType vecTy) -> SmallVector<Value> {
     auto length = vecTy.getNumElements();
     if (isStore) {
       Value valsVec =
@@ -677,7 +679,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
     }
   };
   return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, rewriter,
-                   targetInfo, {}, emitCpAsync);
+                   targetInfo, {}, emitLdSt);
 }
 
 SmallVector<Value> lowerLdSt(
@@ -859,11 +861,13 @@ bool emitTransferBetweenRegistersAndShared(
   auto allocShape = sharedTy.getAllocShape();
   auto invertAllocSharedLayout = LinearLayout::empty();
   if (!paddedLayout) {
-    // For now this is only needed for the cases where we have swizzling.
-    invertAllocSharedLayout =
-        triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
-                                    sharedTy.getEncoding())
-            .pseudoinvert();
+    // This is the legacy way of doing things that's much more ad-hoc
+    // For generic shared layouts it may or may not be correct
+    auto allocShape = sharedTy.getAllocShape();
+    auto trimShape = allocShape.take_back(sharedTy.getRank());
+    invertAllocSharedLayout = triton::gpu::toLinearLayout(
+                                  trimShape, sharedTy.getEncoding(), trimShape)
+                                  .pseudoinvert();
   }
 
   int numElems = regToSharedLayout.getInDimSize(kRegister);
@@ -1473,7 +1477,7 @@ delinearize(RewriterBase &rewriter, Location loc,
             triton::gpu::DistributedEncodingTrait layout,
             ArrayRef<int64_t> shape, StringAttr dimName, Value linear) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
-  auto ll = triton::gpu::toLinearLayout(shape, layout);
+  auto ll = triton::gpu::toLinearLayout(shape, layout, {});
   auto linearLayout =
       triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll);
   assert(ll.hasInDim(dimName));
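The `invertAllocSharedLayout ... .pseudoinvert()` path above, together with the `lstsq` generalisation mentioned in the commit message, boils down to inverting an injective GF(2)-linear map on its image. A standalone sketch of one way to compute such a left inverse (illustrative, assumed code; this is not Triton's `lstsq` or `pseudoinvert`):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// cols[j] = image of input bit j, packed as a bitmask of output bits.
// Returns result[j] = bitmask of output bits whose XOR recovers input bit j.
std::vector<uint64_t> leftInverseGF2(const std::vector<uint64_t> &cols,
                                     int numOutBits) {
  int numInBits = static_cast<int>(cols.size());
  struct Row {
    uint64_t value = 0; // current combination, expressed over input bits
    uint64_t tag = 0;   // which original output rows were XOR-ed together
  };
  // Row i of the matrix: bit j is set iff output bit i appears in column j.
  std::vector<Row> rows(numOutBits);
  for (int i = 0; i < numOutBits; ++i) {
    rows[i].tag = uint64_t(1) << i;
    for (int j = 0; j < numInBits; ++j)
      if ((cols[j] >> i) & 1)
        rows[i].value |= uint64_t(1) << j;
  }
  // Gauss-Jordan: give each input bit its own pivot row.
  std::vector<int> pivotOf(numInBits, -1);
  std::vector<bool> used(numOutBits, false);
  for (int j = 0; j < numInBits; ++j) {
    for (int i = 0; i < numOutBits; ++i)
      if (!used[i] && ((rows[i].value >> j) & 1)) {
        pivotOf[j] = i;
        break;
      }
    assert(pivotOf[j] != -1 && "map is not injective");
    used[pivotOf[j]] = true;
    for (int i = 0; i < numOutBits; ++i)
      if (i != pivotOf[j] && ((rows[i].value >> j) & 1)) {
        rows[i].value ^= rows[pivotOf[j]].value;
        rows[i].tag ^= rows[pivotOf[j]].tag;
      }
  }
  // After full elimination, pivot row j represents exactly input bit j.
  std::vector<uint64_t> result(numInBits);
  for (int j = 0; j < numInBits; ++j)
    result[j] = rows[pivotOf[j]].tag;
  return result;
}
```

Applying `result[j]` as an XOR-parity mask over the output coordinates recovers offset bit `j`, which is conceptually how a register-to-shared transfer walks back from tensor coordinates to shared-memory offsets.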

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 11 additions & 17 deletions

@@ -471,35 +471,24 @@ struct MemDescSubviewOpConversion
     // newBase = base + offset
     auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                    llvmElemTy, rewriter);
-    auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
     SmallVector<Value> opOffsetVals = op.getOffsets();
     // We assume we always create a subview of the last dimensions
-    SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
-                                     smemStrides.end());
     // Compute total offset
-    SmallVector<Value> offsetVals;
-    auto destRank = op.getResult().getType().getRank();
-    auto rankReduced = srcTy.getRank() - destRank;
-    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
-      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
-    }
+    auto rankReduced = srcTy.getRank() - destTy.getRank();
 
     Value offset;
     if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) {
+      auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
+      SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
+                                       smemStrides.end());
       // We are splitting the pipelining dimension which may not be a power of 2
       // so we can't use LinearLayouts
       offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
     } else {
       auto dimNames = standardOutDimNames(ctx, opOffsetVals.size());
       SmallVector<std::pair<StringAttr, Value>> logicalOffsets;
-      // This assumes the subviews are additive, in the sense that we can
-      // compute the offset of one and an add it to the offset of the previous
-      // one we computed. We check for this in the verifier.
-      for (int i = 0; i < rankReduced; i++) {
-        logicalOffsets.push_back({dimNames[i], b.i32_val(0)});
-      }
-      for (int i = rankReduced; i < opOffsetVals.size(); i++) {
-        logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]});
+      for (auto [dim, offset] : llvm::zip(dimNames, opOffsetVals)) {
+        logicalOffsets.push_back({dim, offset});
       }
       auto ll = toLinearLayout(srcTy);
       // Checked in the verifier.
@@ -517,6 +506,11 @@
       offset = b.add(offset, padOffset);
     }
 
+    SmallVector<Value> offsetVals;
+    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
+      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
+    }
+
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
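In the LinearLayout branch above, the logical offsets are fed through the source layout `ll` to obtain the shared-memory offset of the subview. As a scalar sketch of what applying an XOR-based linear layout to an index means (illustrative code, not the lowering itself, which emits the equivalent arithmetic on dynamic `Value`s):

```cpp
#include <cstdint>
#include <vector>

// Apply a linear layout given by its bases to a logical index: the result is
// the XOR of the basis vectors selected by the set bits of the index.
int32_t applyLinearLayout(const std::vector<int32_t> &bases, int32_t index) {
  int32_t out = 0;
  for (size_t i = 0; i < bases.size(); ++i)
    if ((index >> i) & 1)
      out ^= bases[i]; // contribution of input bit i
  return out;
}
```

For instance, with the `A` bases from the commit message, `applyLinearLayout({1, 2, 5, 10, 32, 16}, 3)` returns `1 ^ 2 = 3`.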

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 18 additions & 13 deletions

@@ -39,11 +39,14 @@ namespace gpu {
 
 LinearEncodingAttr TritonGPUDialect::toLinearEncoding(ArrayRef<int64_t> shape,
                                                       Attribute layout) {
-  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
+  // LinearEncoding is a DistributedLayout
+  std::vector<int64_t> allocationShape;
+  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout,
+               allocationShape};
   if (auto result = leCache.get(key)) {
     return *result;
   }
-  auto linearLayout = toLinearLayout(shape, layout);
+  auto linearLayout = toLinearLayout(shape, layout, {});
   auto linearEncoding =
       LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout));
   leCache.set(key, linearEncoding);
@@ -2369,7 +2372,7 @@ struct TritonGPUInferLayoutInterface
       return success();
     }
 
-    auto ll = toLinearLayout(shape, operandEncoding);
+    auto ll = toLinearLayout(shape, operandEncoding, {});
     auto transposedLl = transposeLinearLayout(ll, order);
     resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
     return success();
@@ -2491,8 +2494,9 @@
                              CTALayout);
     // Big guns, check linear layouts are equivalent
     // We disallow reshaping memdesc_subviews in the verifier
-    auto srcLL = toLinearLayout(srcShape, srcEnc);
-    auto dstLL = toLinearLayout(dstShape, dstEnc);
+    // We disallow reshaping memdesc_subviews in the verifier
+    auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
+    auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
     if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
       return failure();
     }
@@ -2774,7 +2778,7 @@
     SmallVector<int64_t> joinedShape(shape);
     joinedShape.push_back(2);
     auto parent = enc.getParent();
-    auto parentLL = toLinearLayout(joinedShape, parent);
+    auto parentLL = toLinearLayout(joinedShape, parent, {});
 
     Attribute splitEnc;
     auto result = inferSplitOpEncoding(parent, splitEnc, joinedShape, loc);
@@ -2810,7 +2814,7 @@
     }
 
     // Append dim to shape
-    auto ll = toLinearLayout(shape, srcEnc);
+    auto ll = toLinearLayout(shape, srcEnc, {});
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
     dstShape.push_back(1);
     ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));
@@ -2866,7 +2870,7 @@
     auto ctx = getContext();
 
     // Split on last dim
-    auto ll = toLinearLayout(shape, srcEnc);
+    auto ll = toLinearLayout(shape, srcEnc, {});
     auto newLl = LinearLayout::empty();
     auto result =
         tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc);
@@ -2935,7 +2939,7 @@
       }
     }
 
-    auto ll = toLinearLayout(shape, inEnc);
+    auto ll = toLinearLayout(shape, inEnc, {});
     auto newLl = LinearLayout::empty();
     auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
     if (!result.succeeded())
@@ -3042,15 +3046,16 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
     return "";
 
   // This RankedTensorType is a MemDescType (?!)
-  LinearLayout ll = triton::gpu::toLinearLayout(type);
+  auto shape = type.getShape();
+  auto layout = type.getEncoding();
+  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout, shape);
 
   StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
   StringAttr kBlock = StringAttr::get(type.getContext(), "block");
   int64_t tensorSize = product(type.getShape());
   auto enc = type.getEncoding();
   unsigned numBlocks = getNumCTAs(enc);
   int32_t blockSize = tensorSize / numBlocks;
-  auto shape = type.getShape();
 
   // elementMapping is for the non-hw layout, offsetMapping for hw-layout
   std::vector<std::string> elementMapping(tensorSize);
@@ -3463,8 +3468,8 @@ int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) {
 bool triton::gpu::areLayoutsEquivalent(ArrayRef<int64_t> shape,
                                        DistributedEncodingTrait lhs,
                                        DistributedEncodingTrait rhs) {
-  auto lhsLL = triton::gpu::toLinearLayout(shape, lhs);
-  auto rhsLL = triton::gpu::toLinearLayout(shape, rhs);
+  auto lhsLL = triton::gpu::toLinearLayout(shape, lhs, {});
+  auto rhsLL = triton::gpu::toLinearLayout(shape, rhs, {});
   return lhsLL == rhsLL;
 }
