
Commit 9275820

Merge OpenAI Triton commit 91d58f5 (#4753)
This PR changes the Triton base from 4b36a8a to 91d58f5 (Jul 11). Pass rate: 98.46%
2 parents ada4e6e + f628bc9 commit 9275820

File tree

30 files changed: +180 -179 lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 1 deletion
@@ -27,6 +27,7 @@ class ReduceOpHelper {
   explicit ReduceOpHelper(triton::ReduceOp op)
       : op(op.getOperation()), axis(op.getAxis()) {
     auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
+    srcTy = firstTy;
     srcShape = firstTy.getShape();
     srcEncoding = firstTy.getEncoding();
     srcElementTypes = op.getElementTypes();
@@ -68,6 +69,7 @@ class ReduceOpHelper {

 private:
   triton::ReduceOp op;
+  RankedTensorType srcTy;
   ArrayRef<int64_t> srcShape;
   Attribute srcEncoding;
   SmallVector<Type> srcElementTypes;
@@ -80,7 +82,7 @@ class ScanLoweringHelper {
     auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
     srcShape = firstTy.getShape();
     legacyEncoding = firstTy.getEncoding();
-    srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
+    srcEncoding = triton::gpu::toLinearEncoding(firstTy);
     srcElementTypes = op.getElementTypes();
     // The codegen does not support different element/thread/warp order so
     // we choose one a priori. We choose that of the blocked encoding.
@@ -166,6 +168,8 @@ class GatherLoweringHelper {

 private:
   triton::GatherOp gatherOp;
+  RankedTensorType srcTy;
+  RankedTensorType dstTy;
 };

 // This struct represents a decomposed layout conversion within a warp into

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 5 deletions
@@ -93,7 +93,8 @@ struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {

 // Convert a distributed layout to a linear encoding
 LinearEncodingAttr toLinearEncoding(RankedTensorType type);
-LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape);
+LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout,
+                                    ArrayRef<int64_t> shape);

 unsigned getTotalElemsPerThread(Type type);

@@ -274,14 +275,13 @@ llvm::SmallVector<unsigned>
 expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

 // Return true if the two layouts represent the exact same mapping.
-bool areLayoutsEquivalent(ArrayRef<int64_t> shape, Attribute lhs,
-                          Attribute rhs);
+bool areLayoutsEquivalent(ArrayRef<int64_t> shape, DistributedEncodingTrait lhs,
+                          DistributedEncodingTrait rhs);

 // Return true if the innermost numElems are contiguous.
 bool isInnermostContiguous(MemDescType type, unsigned numElems);

-LinearLayout inferReshapeLinearLayout(ArrayRef<int64_t> srcShape,
-                                      Attribute srcEnc,
+LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
                                       ArrayRef<int64_t> dstShape);

 // Verify the types of operations that operate on memory.
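
Not part of the diff: a minimal sketch of how a caller might use the type-based toLinearEncoding overload declared above, assuming the TritonGPU dialect header is available. The helper name getLayoutOrder is hypothetical; the point is that a RankedTensorType already carries both shape and encoding, so call sites no longer unpack them.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace {
// Returns the dimension order of a tensor's layout via the type-based overload.
auto getLayoutOrder(mlir::RankedTensorType tensorTy) {
  // Previously: toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape())
  auto linAttr = mlir::triton::gpu::toLinearEncoding(tensorTy);
  return linAttr.getOrder();
}
} // namespace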

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 2 deletions
@@ -17,6 +17,8 @@ class SwizzledSharedEncodingAttr;
 class NVMMASharedEncodingAttr;
 class AMDRotatingSharedEncodingAttr;
 class AMDMfmaEncodingAttr;
+class TensorOrMemDesc;
+class MemDescType;

 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -45,9 +47,10 @@ class AMDMfmaEncodingAttr;
 // elemBitWidth is the bit width of one element in the layout. This is required
 // to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
 // shared layouts with nvmma_shared layout) but is otherwise unused.
-//
-// Returns std::nullopt if the given layout can't be converted to an LL.
 LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
+LinearLayout toLinearLayout(RankedTensorType type);
+LinearLayout toLinearLayout(MemDescType type);
+LinearLayout toLinearLayout(TensorOrMemDesc type);

 // Convert the shared encoding of a tensor with `nvmma_shared` layout to a
 // LinearLayout that maps from a linear shared memory offset to tensor index.
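
Not part of the diff: a minimal sketch of the call-site migration the rest of this commit applies, assuming the headers above; the function name regToSharedConversion is hypothetical. The type-based toLinearLayout overloads replace the shape-plus-encoding form, and the resulting layouts compose exactly as before.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

namespace {
using namespace mlir;
using namespace mlir::triton;

// Composes a register layout with a shared-memory layout, mirroring the
// updated call sites in MemoryOpToLLVM.cpp further down in this commit.
LinearLayout regToSharedConversion(RankedTensorType regTy,
                                   gpu::MemDescType memDescTy) {
  // Old form: toLinearLayout(regTy.getShape(), regTy.getEncoding())
  LinearLayout regLayout = gpu::toLinearLayout(regTy);
  LinearLayout sharedLayout = gpu::toLinearLayout(memDescTy);
  return regLayout.invertAndCompose(sharedLayout);
}
} // namespace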

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 4 deletions
@@ -42,8 +42,8 @@ static unsigned getBitwidth(RankedTensorType ty) {
 static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                               RankedTensorType dstTy) {
   auto *ctx = srcTy.getContext();
-  auto srcLayout = gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-  auto dstLayout = gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  auto srcLayout = gpu::toLinearLayout(srcTy);
+  auto dstLayout = gpu::toLinearLayout(dstTy);
   srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
   dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
   auto bitwidth = getBitwidth(srcTy);
@@ -109,8 +109,8 @@ getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();

-  auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
-  auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
+  auto srcLinAttr = gpu::toLinearEncoding(srcTy);
+  auto dstLinAttr = gpu::toLinearEncoding(dstTy);
   auto inOrd = srcLinAttr.getOrder();
   auto outOrd = dstLinAttr.getOrder();


lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 6 deletions
@@ -1232,8 +1232,7 @@ unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue,
   // the analysis to one dimension. We should determine contiguity on the
   // flattenOuts() layout
   auto tensorTy = cast<RankedTensorType>(offsetsValue.getType());
-  auto linAttr =
-      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto linAttr = gpu::toLinearEncoding(tensorTy);
   auto order = linAttr.getOrder();
   unsigned align = getAlignment(offsetsValue, elementBitWidth);

@@ -1266,8 +1265,7 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue,
   auto *axisInfo = getAxisInfo(offsetsValue);
   if (!axisInfo)
     return 1;
-  auto linAttr =
-      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto linAttr = gpu::toLinearEncoding(tensorTy);
   auto order = linAttr.getOrder();
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
@@ -1295,8 +1293,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto linAttr =
-      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto linAttr = gpu::toLinearEncoding(tensorTy);
   auto maskOrder = linAttr.getOrder();
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "

lib/Analysis/Utility.cpp

Lines changed: 6 additions & 10 deletions
@@ -24,7 +24,7 @@ using namespace triton;
 using namespace triton::gpu;

 SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
-  auto order = toLinearEncoding(srcEncoding, srcShape).getOrder();
+  auto order = toLinearEncoding(srcTy).getOrder();
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
@@ -37,7 +37,7 @@ SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
 // reduction axis within the warp.
 unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
   auto *ctx = srcEncoding.getContext();
-  auto linearLayout = toLinearLayout(srcShape, srcEncoding);
+  auto linearLayout = toLinearLayout(srcTy);
   auto kLane = mlir::StringAttr::get(ctx, "lane");
   const auto &bases = linearLayout.getBases();
   const auto &lanes = bases.find(kLane)->second;
@@ -576,10 +576,8 @@ bool GatherLoweringHelper::isWarpLocal() {
   // source and index tensors, all the elements are owned by the same warp.
   RankedTensorType srcType = gatherOp.getSrc().getType();
   RankedTensorType idxType = gatherOp.getIndices().getType();
-  LinearLayout srcLayout =
-      toLinearLayout(srcType.getShape(), srcType.getEncoding());
-  LinearLayout idxLayout =
-      toLinearLayout(idxType.getShape(), idxType.getEncoding());
+  LinearLayout srcLayout = toLinearLayout(srcType);
+  LinearLayout idxLayout = toLinearLayout(idxType);

   Builder b(gatherOp.getContext());
   StringAttr kBlock = b.getStringAttr("block");
@@ -766,10 +764,8 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
 LinearLayout minimalCvtLayout(Type srcTy_, Type dstTy_) {
   auto srcTy = cast<triton::gpu::TensorOrMemDesc>(srcTy_);
   auto dstTy = cast<triton::gpu::TensorOrMemDesc>(dstTy_);
-  LinearLayout srcLayout =
-      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-  LinearLayout dstLayout =
-      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  LinearLayout srcLayout = toLinearLayout(srcTy);
+  LinearLayout dstLayout = toLinearLayout(dstTy);
   auto sDims = to_vector(srcLayout.getInDimNames());
   auto dDims = to_vector(dstLayout.getInDimNames());
   SmallVector<StringAttr> dims;

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 4 additions & 6 deletions
@@ -43,10 +43,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto dstTy = op.getType();

     LinearLayout conversion = minimalCvtLayout(srcTy, dstTy);
-    LinearLayout srcLayout =
-        toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    LinearLayout dstLayout =
-        toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+    LinearLayout srcLayout = toLinearLayout(srcTy);
+    LinearLayout dstLayout = toLinearLayout(dstTy);

     StringAttr kBlock = str_attr("block");
     StringAttr kWarp = str_attr("warp");
@@ -246,8 +244,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

     // Remove the kBlock dimension from the layout as it's the identity in the
     // cvt
-    auto srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    auto dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+    auto srcLayout = toLinearLayout(srcTy);
+    auto dstLayout = toLinearLayout(dstTy);
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
@@ -209,10 +209,8 @@ void GatherOpConversion::emitWarpLocalGather(
   }

   // Compute the src and idx layouts.
-  LinearLayout srcLayout =
-      toLinearLayout(srcType.getShape(), srcType.getEncoding());
-  LinearLayout idxLayout =
-      toLinearLayout(idxType.getShape(), idxType.getEncoding());
+  LinearLayout srcLayout = toLinearLayout(srcType);
+  LinearLayout idxLayout = toLinearLayout(idxType);

   // Let `ll_src` be the source layout and `ll_idx` be the index layout.
   // Let `src_col` be a tuple of dimensions except the gather dimension,

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 4 additions & 6 deletions
@@ -39,9 +39,8 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());

-  auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
-  auto sharedLayout =
-      toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
+  auto regLayout = toLinearLayout(regTy);
+  auto sharedLayout = toLinearLayout(memDescTy);
   auto cvt = regLayout.invertAndCompose(sharedLayout);

   auto kBlock = str_attr("block");
@@ -193,9 +192,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
      return success();
    }

-    auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
-    auto sharedLayout =
-        toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
+    auto regLayout = toLinearLayout(regTy);
+    auto sharedLayout = toLinearLayout(memDescTy);
    auto cvt = regLayout.invertAndCompose(sharedLayout);
    auto kBlock = str_attr("block");
    // NYI. We would need to emit a map.shared::cluster instruction.

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 6 deletions
@@ -832,7 +832,7 @@ bool emitTransferBetweenRegistersAndShared(
     regToSharedLayout =
         regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
   } else {
-    auto sharedLL = triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());
+    auto sharedLL = triton::gpu::toLinearLayout(sharedTy);
     regToSharedLayout = regLayout.invertAndCompose(sharedLL);
   }

@@ -908,8 +908,7 @@ bool emitTransferBetweenRegistersAndShared(
     const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
-                                               registerTy.getEncoding());
+  auto regLayout = triton::gpu::toLinearLayout(registerTy);
   auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   return emitTransferBetweenRegistersAndShared(
       regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
@@ -1131,8 +1130,7 @@ llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
   if (!tensorTy) {
     return getAllFreeVarMasks(ctx);
   }
-  auto ll =
-      triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
+  auto ll = triton::gpu::toLinearLayout(tensorTy);
   return ll.getFreeVariableMasks();
 }

@@ -1142,7 +1140,7 @@ SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
   auto shape = type.getShape();
   unsigned rank = shape.size();

-  auto ll = triton::gpu::toLinearLayout(shape, layout);
+  auto ll = triton::gpu::toLinearLayout(type);

   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
