Skip to content

Commit 91d58f5

Browse files
authored
[NFC] Make toLinearLayout take a RankedTensorType or MemDescType (#7440)
This PR is an NFC that mostly improves readability and usability. There are a few places where we bend over backwards to pass the full type to `toLinearLayout`, though. This is because this PR is in preparation for the next PR where we fix `memdesc_subview` at large. To be able to do this, we need the `allocationShape` from `MemDescType` to create the associated `LinearLayout`, so this was the cleanest way to make sure we don't introduce any subtle errors.
1 parent 4b36a8a commit 91d58f5

File tree

29 files changed

+178
-174
lines changed

29 files changed

+178
-174
lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ReduceOpHelper {
2727
explicit ReduceOpHelper(triton::ReduceOp op)
2828
: op(op.getOperation()), axis(op.getAxis()) {
2929
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
30+
srcTy = firstTy;
3031
srcShape = firstTy.getShape();
3132
srcEncoding = firstTy.getEncoding();
3233
srcElementTypes = op.getElementTypes();
@@ -68,6 +69,7 @@ class ReduceOpHelper {
6869

6970
private:
7071
triton::ReduceOp op;
72+
RankedTensorType srcTy;
7173
ArrayRef<int64_t> srcShape;
7274
Attribute srcEncoding;
7375
SmallVector<Type> srcElementTypes;
@@ -80,7 +82,7 @@ class ScanLoweringHelper {
8082
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
8183
srcShape = firstTy.getShape();
8284
legacyEncoding = firstTy.getEncoding();
83-
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
85+
srcEncoding = triton::gpu::toLinearEncoding(firstTy);
8486
srcElementTypes = op.getElementTypes();
8587
// The codegen does not support different element/thread/warp order so
8688
// we choose one a priori. We choose that of the blocked encoding.
@@ -166,6 +168,8 @@ class GatherLoweringHelper {
166168

167169
private:
168170
triton::GatherOp gatherOp;
171+
RankedTensorType srcTy;
172+
RankedTensorType dstTy;
169173
};
170174

171175
// This struct represents a decomposed layout conversion within a warp into

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
9393

9494
// Convert a distributed layout to a linear encoding
9595
LinearEncodingAttr toLinearEncoding(RankedTensorType type);
96-
LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape);
96+
LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout,
97+
ArrayRef<int64_t> shape);
9798

9899
unsigned getTotalElemsPerThread(Type type);
99100

@@ -274,14 +275,13 @@ llvm::SmallVector<unsigned>
274275
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
275276

276277
// Return true if the two layouts represent the exact same mapping.
277-
bool areLayoutsEquivalent(ArrayRef<int64_t> shape, Attribute lhs,
278-
Attribute rhs);
278+
bool areLayoutsEquivalent(ArrayRef<int64_t> shape, DistributedEncodingTrait lhs,
279+
DistributedEncodingTrait rhs);
279280

280281
// Return true if the innermost numElems are contiguous.
281282
bool isInnermostContiguous(MemDescType type, unsigned numElems);
282283

283-
LinearLayout inferReshapeLinearLayout(ArrayRef<int64_t> srcShape,
284-
Attribute srcEnc,
284+
LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
285285
ArrayRef<int64_t> dstShape);
286286

287287
// Verify the types of operations that operate on memory.

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class SwizzledSharedEncodingAttr;
1717
class NVMMASharedEncodingAttr;
1818
class AMDRotatingSharedEncodingAttr;
1919
class AMDMfmaEncodingAttr;
20+
class TensorOrMemDesc;
21+
class MemDescType;
2022

2123
// - BlockedEncodingAttrs have the following input dimensions.
2224
//
@@ -45,9 +47,10 @@ class AMDMfmaEncodingAttr;
4547
// elemBitWidth is the bit width of one element in the layout. This is required
4648
// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
4749
// shared layouts with nvmma_shared layout) but is otherwise unused.
48-
//
49-
// Returns std::nullopt if the given layout can't be converted to an LL.
5050
LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
51+
LinearLayout toLinearLayout(RankedTensorType type);
52+
LinearLayout toLinearLayout(MemDescType type);
53+
LinearLayout toLinearLayout(TensorOrMemDesc type);
5154

5255
// Convert the shared encoding of a tensor with `nvmma_shared` layout to a
5356
// LinearLayout that maps from a linear shared memory offset to tensor index.

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ static unsigned getBitwidth(RankedTensorType ty) {
4242
static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4343
RankedTensorType dstTy) {
4444
auto *ctx = srcTy.getContext();
45-
auto srcLayout = gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
46-
auto dstLayout = gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
45+
auto srcLayout = gpu::toLinearLayout(srcTy);
46+
auto dstLayout = gpu::toLinearLayout(dstTy);
4747
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
4848
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
4949
auto bitwidth = getBitwidth(srcTy);
@@ -109,8 +109,8 @@ getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
109109
Attribute srcLayout = srcTy.getEncoding();
110110
Attribute dstLayout = dstTy.getEncoding();
111111

112-
auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
113-
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
112+
auto srcLinAttr = gpu::toLinearEncoding(srcTy);
113+
auto dstLinAttr = gpu::toLinearEncoding(dstTy);
114114
auto inOrd = srcLinAttr.getOrder();
115115
auto outOrd = dstLinAttr.getOrder();
116116

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,8 +1232,7 @@ unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue,
12321232
// the analysis to one dimension. We should determine contiguity on the
12331233
// flattenOuts() layout
12341234
auto tensorTy = cast<RankedTensorType>(offsetsValue.getType());
1235-
auto linAttr =
1236-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1235+
auto linAttr = gpu::toLinearEncoding(tensorTy);
12371236
auto order = linAttr.getOrder();
12381237
unsigned align = getAlignment(offsetsValue, elementBitWidth);
12391238

@@ -1266,8 +1265,7 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue,
12661265
auto *axisInfo = getAxisInfo(offsetsValue);
12671266
if (!axisInfo)
12681267
return 1;
1269-
auto linAttr =
1270-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1268+
auto linAttr = gpu::toLinearEncoding(tensorTy);
12711269
auto order = linAttr.getOrder();
12721270
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12731271
auto maxContig = axisInfo->getContiguity(order[0]);
@@ -1295,8 +1293,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12951293
auto *axisInfo = getAxisInfo(mask);
12961294
if (!axisInfo)
12971295
return 1;
1298-
auto linAttr =
1299-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1296+
auto linAttr = gpu::toLinearEncoding(tensorTy);
13001297
auto maskOrder = linAttr.getOrder();
13011298
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
13021299
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "

lib/Analysis/Utility.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using namespace triton;
2323
using namespace triton::gpu;
2424

2525
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
26-
auto order = toLinearEncoding(srcEncoding, srcShape).getOrder();
26+
auto order = toLinearEncoding(srcTy).getOrder();
2727
auto it = std::find(order.begin(), order.end(), axis);
2828
// delete the axis from order
2929
order.erase(it);
@@ -36,7 +36,7 @@ SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
3636
// reduction axis within the warp.
3737
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
3838
auto *ctx = srcEncoding.getContext();
39-
auto linearLayout = toLinearLayout(srcShape, srcEncoding);
39+
auto linearLayout = toLinearLayout(srcTy);
4040
auto kLane = mlir::StringAttr::get(ctx, "lane");
4141
const auto &bases = linearLayout.getBases();
4242
const auto &lanes = bases.find(kLane)->second;
@@ -570,10 +570,8 @@ bool GatherLoweringHelper::isWarpLocal() {
570570
// source and index tensors, all the elements are owned by the same warp.
571571
RankedTensorType srcType = gatherOp.getSrc().getType();
572572
RankedTensorType idxType = gatherOp.getIndices().getType();
573-
LinearLayout srcLayout =
574-
toLinearLayout(srcType.getShape(), srcType.getEncoding());
575-
LinearLayout idxLayout =
576-
toLinearLayout(idxType.getShape(), idxType.getEncoding());
573+
LinearLayout srcLayout = toLinearLayout(srcType);
574+
LinearLayout idxLayout = toLinearLayout(idxType);
577575

578576
Builder b(gatherOp.getContext());
579577
StringAttr kBlock = b.getStringAttr("block");
@@ -760,10 +758,8 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
760758
LinearLayout minimalCvtLayout(Type srcTy_, Type dstTy_) {
761759
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(srcTy_);
762760
auto dstTy = cast<triton::gpu::TensorOrMemDesc>(dstTy_);
763-
LinearLayout srcLayout =
764-
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
765-
LinearLayout dstLayout =
766-
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
761+
LinearLayout srcLayout = toLinearLayout(srcTy);
762+
LinearLayout dstLayout = toLinearLayout(dstTy);
767763
auto sDims = to_vector(srcLayout.getInDimNames());
768764
auto dDims = to_vector(dstLayout.getInDimNames());
769765
SmallVector<StringAttr> dims;

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
4343
auto dstTy = op.getType();
4444

4545
LinearLayout conversion = minimalCvtLayout(srcTy, dstTy);
46-
LinearLayout srcLayout =
47-
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
48-
LinearLayout dstLayout =
49-
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
46+
LinearLayout srcLayout = toLinearLayout(srcTy);
47+
LinearLayout dstLayout = toLinearLayout(dstTy);
5048

5149
StringAttr kBlock = str_attr("block");
5250
StringAttr kWarp = str_attr("warp");
@@ -246,8 +244,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
246244

247245
// Remove the kBlock dimension from the layout as it's the identity in the
248246
// cvt
249-
auto srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
250-
auto dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
247+
auto srcLayout = toLinearLayout(srcTy);
248+
auto dstLayout = toLinearLayout(dstTy);
251249
auto kReg = str_attr("register");
252250
auto kLane = str_attr("lane");
253251
auto kWarp = str_attr("warp");

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,8 @@ void GatherOpConversion::emitWarpLocalGather(
209209
}
210210

211211
// Compute the src and idx layouts.
212-
LinearLayout srcLayout =
213-
toLinearLayout(srcType.getShape(), srcType.getEncoding());
214-
LinearLayout idxLayout =
215-
toLinearLayout(idxType.getShape(), idxType.getEncoding());
212+
LinearLayout srcLayout = toLinearLayout(srcType);
213+
LinearLayout idxLayout = toLinearLayout(idxType);
216214

217215
// Let `ll_src` be the source layout and `ll_idx` be the index layout.
218216
// Let `src_col` be a tuple of dimensions except the gather dimension,

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
3939
auto regTy = cast<RankedTensorType>(regVal.getType());
4040
auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
4141

42-
auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
43-
auto sharedLayout =
44-
toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
42+
auto regLayout = toLinearLayout(regTy);
43+
auto sharedLayout = toLinearLayout(memDescTy);
4544
auto cvt = regLayout.invertAndCompose(sharedLayout);
4645

4746
auto kBlock = str_attr("block");
@@ -193,9 +192,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
193192
return success();
194193
}
195194

196-
auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
197-
auto sharedLayout =
198-
toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
195+
auto regLayout = toLinearLayout(regTy);
196+
auto sharedLayout = toLinearLayout(memDescTy);
199197
auto cvt = regLayout.invertAndCompose(sharedLayout);
200198
auto kBlock = str_attr("block");
201199
// NYI. We would need to emit a map.shared::cluster instruction.

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ bool emitTransferBetweenRegistersAndShared(
811811
regToSharedLayout =
812812
regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
813813
} else {
814-
auto sharedLL = triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());
814+
auto sharedLL = triton::gpu::toLinearLayout(sharedTy);
815815
regToSharedLayout = regLayout.invertAndCompose(sharedLL);
816816
}
817817

@@ -887,8 +887,7 @@ bool emitTransferBetweenRegistersAndShared(
887887
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
888888
const TargetInfoBase &target,
889889
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
890-
auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
891-
registerTy.getEncoding());
890+
auto regLayout = triton::gpu::toLinearLayout(registerTy);
892891
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
893892
return emitTransferBetweenRegistersAndShared(
894893
regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
@@ -1110,8 +1109,7 @@ llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
11101109
if (!tensorTy) {
11111110
return getAllFreeVarMasks(ctx);
11121111
}
1113-
auto ll =
1114-
triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
1112+
auto ll = triton::gpu::toLinearLayout(tensorTy);
11151113
return ll.getFreeVariableMasks();
11161114
}
11171115

@@ -1121,7 +1119,7 @@ SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
11211119
auto shape = type.getShape();
11221120
unsigned rank = shape.size();
11231121

1124-
auto ll = triton::gpu::toLinearLayout(shape, layout);
1122+
auto ll = triton::gpu::toLinearLayout(type);
11251123

11261124
StringAttr kRegister = str_attr("register");
11271125
StringAttr kLane = str_attr("lane");

0 commit comments

Comments (0)