Skip to content

Commit 4b36a8a

Browse files
authored
[NFC] Use RankedTensorType's clone and cloneWithEncoding member functions (#7464)
Credit to Jeff for pointing out that these exist.
1 parent 6f99791 commit 4b36a8a

File tree

25 files changed

+80
-158
lines changed

25 files changed

+80
-158
lines changed

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
108108
if (auto intTy = llvm::dyn_cast<IntegerType>(blockType.getElementType())) {
109109
auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned;
110110
auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem);
111-
blockType = RankedTensorType::get(blockType.getShape(), elemTy);
111+
blockType = blockType.clone(elemTy);
112112
}
113113
return Base::get($_ctxt, blockType);
114114
}]>,
@@ -119,7 +119,7 @@ def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
119119
if (auto intTy = llvm::dyn_cast<IntegerType>(resTy.getElementType())) {
120120
auto width = resTy.getElementTypeBitWidth();
121121
auto signlessTy = IntegerType::get(getContext(), width);
122-
resTy = RankedTensorType::get(resTy.getShape(), signlessTy);
122+
resTy = resTy.clone(signlessTy);
123123
}
124124
return resTy;
125125
}

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,8 +724,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
724724
return false;
725725
}
726726
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
727-
auto parentTy = RankedTensorType::get(
728-
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
727+
auto parentTy = srcTy.cloneWithEncoding(dotOperandLayout.getParent());
729728
auto ans = mmaLayout.getVersionMajor() == 3 &&
730729
dotOperandLayout.getOpIdx() == 0 &&
731730
mmaLayout.getWarpsPerCTA()[1] == 1 &&

lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
3131
encoding = ttng::getTmemCompatibleLayout(
3232
tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps);
3333
}
34-
return RankedTensorType::get(type.getShape(), type.getElementType(),
35-
encoding);
34+
return type.cloneWithEncoding(encoding);
3635
}
3736

3837
struct TMEMLoadOpPattern : public OpConversionPattern<ttng::TMEMLoadOp> {

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
3434
triton::gpu::BlockedEncodingAttr encoding =
3535
getDefaultBlockedEncoding(this->context, shape, this->numWarps,
3636
this->threadsPerWarp, this->numCTAs);
37-
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
37+
return tensorType.cloneWithEncoding(encoding);
3838
});
3939

4040
// Add encoding for tensor pointer
@@ -150,8 +150,7 @@ static RankedTensorType getNewIndicesType(RankedTensorType type,
150150
if (enc == newEncoding)
151151
return {};
152152

153-
return RankedTensorType::get(type.getShape(), type.getElementType(),
154-
newEncoding);
153+
return type.cloneWithEncoding(newEncoding);
155154
}
156155

157156
// Function for converting any gather or scatter op that requires a specific

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ struct TritonExpandDimsPattern
172172
// convert operand to slice of return type
173173
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
174174
getContext(), op.getAxis(), retEncoding);
175-
RankedTensorType newArgType = RankedTensorType::get(
176-
argType.getShape(), argType.getElementType(), newArgEncoding);
175+
RankedTensorType newArgType = argType.cloneWithEncoding(newArgEncoding);
177176
// construct new op
178177
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
179178
op.getLoc(), newArgType, adaptor.getSrc());
@@ -238,8 +237,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
238237
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
239238
getContext(), origShape, retSizePerThread, retOrder, numWarps,
240239
threadsPerWarp, numCTAs);
241-
RankedTensorType retType =
242-
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
240+
RankedTensorType retType = origType.cloneWithEncoding(dEncoding);
243241
// a & b must be of smem layout
244242
auto aType = cast<RankedTensorType>(adaptor.getA().getType());
245243
auto bType = cast<RankedTensorType>(adaptor.getB().getType());
@@ -255,15 +253,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
255253
if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
256254
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
257255
getContext(), 0, dEncoding, aEltType);
258-
auto dstType =
259-
RankedTensorType::get(aType.getShape(), aEltType, encoding);
256+
auto dstType = aType.cloneWithEncoding(encoding);
260257
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
261258
}
262259
if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
263260
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
264261
getContext(), 1, dEncoding, bEltType);
265-
auto dstType =
266-
RankedTensorType::get(bType.getShape(), bEltType, encoding);
262+
auto dstType = bType.cloneWithEncoding(encoding);
267263
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
268264
}
269265
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
@@ -313,8 +309,7 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
313309
triton::gpu::BlockedEncodingAttr::get(
314310
getContext(), newRetSizePerThread, retThreadsPerWarp,
315311
retWarpsPerCTA, retOrder, retEncoding.getCTALayout());
316-
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
317-
newRetEncoding);
312+
auto newRetType = retType.cloneWithEncoding(newRetEncoding);
318313
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
319314
op, newRetType, adaptor.getOperands()),
320315
adaptor.getAttributes());
@@ -387,8 +382,7 @@ struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
387382
append(defaultEnc.getCTAsPerCGA(), 1),
388383
append(defaultEnc.getCTASplitNum(), 1),
389384
prepend(defaultEnc.getCTAOrder(), rank - 1)));
390-
srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
391-
srcEnc);
385+
srcTy = srcTy.cloneWithEncoding(srcEnc);
392386
src = rewriter.create<ConvertLayoutOp>(op.getLoc(), srcTy, src);
393387
}
394388

@@ -427,8 +421,7 @@ struct TritonBroadcastPattern
427421
auto srcEncoding = srcType.getEncoding();
428422
if (!srcEncoding)
429423
return failure();
430-
Type retType = RankedTensorType::get(
431-
op.getType().getShape(), op.getType().getElementType(), srcEncoding);
424+
Type retType = op.getType().cloneWithEncoding(srcEncoding);
432425
// Type retType = this->getTypeConverter()->convertType(op.getType());
433426
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
434427
op, retType, adaptor.getOperands()),

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,9 +1244,7 @@ LogicalResult GatherOp::inferReturnTypes(
12441244
auto srcType = cast<RankedTensorType>(adaptor.getSrc().getType());
12451245

12461246
// Shape and encoding of the indices with the element type of the src.
1247-
inferredReturnTypes.push_back(
1248-
RankedTensorType::get(indicesType.getShape(), srcType.getElementType(),
1249-
indicesType.getEncoding()));
1247+
inferredReturnTypes.push_back(indicesType.clone(srcType.getElementType()));
12501248
return success();
12511249
}
12521250

lib/Dialect/Triton/IR/Types.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,16 @@ unsigned getPointeeBitWidth(Type type) {
6464
Type getI1SameShape(Type type) {
6565
auto i1Type = IntegerType::get(type.getContext(), 1);
6666
if (auto tensorTy = dyn_cast<RankedTensorType>(type))
67-
return RankedTensorType::get(tensorTy.getShape(), i1Type,
68-
tensorTy.getEncoding());
67+
return tensorTy.clone(i1Type);
6968
return i1Type;
7069
}
7170

7271
Type getPointeeType(Type type) {
7372
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
7473
// Tensor of pointers
75-
auto shape = tensorTy.getShape();
7674
auto ptrType = dyn_cast<PointerType>(tensorTy.getElementType());
7775
Type pointeeType = ptrType.getPointeeType();
78-
return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding());
76+
return tensorTy.clone(pointeeType);
7977
} else if (auto ptrType = dyn_cast<PointerType>(type)) {
8078
// scalar pointer
8179
Type pointeeType = ptrType.getPointeeType();
@@ -87,17 +85,15 @@ Type getPointeeType(Type type) {
8785
Type getI32SameShape(Type type) {
8886
auto i32Type = IntegerType::get(type.getContext(), 32);
8987
if (auto tensorTy = dyn_cast<RankedTensorType>(type))
90-
return RankedTensorType::get(tensorTy.getShape(), i32Type,
91-
tensorTy.getEncoding());
88+
return tensorTy.clone(i32Type);
9289
return i32Type;
9390
}
9491

9592
Type getPointerTypeSameShape(Type type) {
9693
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
9794
Type elementType = tensorTy.getElementType();
98-
auto shape = tensorTy.getShape();
9995
PointerType ptrType = PointerType::get(elementType, 1);
100-
return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding());
96+
return tensorTy.clone(ptrType);
10197
} else {
10298
return PointerType::get(type, 1);
10399
}

lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ struct MoveBroadcastAfterElementwisePattern
155155

156156
auto srcTy = broadcastOp.getSrc().getType();
157157
auto bcSrcShape = srcTy.getShape();
158-
auto srcEncoding = srcTy.getEncoding();
159158

160159
// Reshape operands to match srcShape
161160
llvm::SmallVector<Value, 4> newOperands;
@@ -167,7 +166,7 @@ struct MoveBroadcastAfterElementwisePattern
167166
}
168167
auto elemTy =
169168
dyn_cast<RankedTensorType>(operand.getType()).getElementType();
170-
auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding);
169+
auto newTy = srcTy.clone(bcSrcShape, elemTy);
171170
if (auto splatOp = llvm::dyn_cast<SplatOp>(definingOp)) {
172171
auto newSplat = rewriter.create<SplatOp>(loc, newTy, splatOp.getSrc());
173172
newOperands.push_back(newSplat);
@@ -191,8 +190,7 @@ struct MoveBroadcastAfterElementwisePattern
191190
auto resultTypes = op->getResultTypes();
192191
for (auto resultTy : resultTypes) {
193192
auto elemTy = dyn_cast<RankedTensorType>(resultTy).getElementType();
194-
newResultTypes.push_back(
195-
RankedTensorType::get(bcSrcShape, elemTy, srcEncoding));
193+
newResultTypes.push_back(srcTy.clone(bcSrcShape, elemTy));
196194
}
197195

198196
// Create new op and broadcast results

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,8 +761,7 @@ int32_t LocalAllocOp::getAlignmentOrDefault() {
761761

762762
static Type removeEncodingIfTensor(Type type) {
763763
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
764-
return RankedTensorType::get(tensorType.getShape(),
765-
tensorType.getElementType());
764+
return tensorType.cloneWithEncoding({});
766765
}
767766
return type;
768767
}

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,7 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
355355
auto mmaEnc = NvidiaMmaEncodingAttr::get(
356356
oldRetType.getContext(), versionMajor, versionMinor, warpsPerTile,
357357
CTALayout, instrShape);
358-
auto newRetType = RankedTensorType::get(
359-
oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
358+
auto newRetType = oldRetType.cloneWithEncoding(mmaEnc);
360359
// convert accumulator
361360
auto oldAcc = dotOp.getOperand(2);
362361
auto newAcc =
@@ -368,8 +367,7 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
368367
auto vType = cast<RankedTensorType>(v.getType());
369368
auto newVEncoding = DotOperandEncodingAttr::get(
370369
v.getContext(), opIdx, newRetType.getEncoding(), minType);
371-
auto newVType = RankedTensorType::get(
372-
vType.getShape(), vType.getElementType(), newVEncoding);
370+
auto newVType = vType.cloneWithEncoding(newVEncoding);
373371
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
374372
};
375373

@@ -476,14 +474,11 @@ static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) {
476474
if (!tensorType)
477475
continue;
478476
Value newOperand = rewriter.create<ConvertLayoutOp>(
479-
operand.get().getLoc(),
480-
RankedTensorType::get(tensorType.getShape(),
481-
tensorType.getElementType(), newLayout),
477+
operand.get().getLoc(), tensorType.cloneWithEncoding(newLayout),
482478
operand.get());
483479
loadOp->setOperand(operand.getOperandNumber(), newOperand);
484480
}
485-
loadOp->getResult(0).setType(RankedTensorType::get(
486-
bType.getShape(), bType.getElementType(), newLayout));
481+
loadOp->getResult(0).setType(bType.cloneWithEncoding(newLayout));
487482
Value newB = loadOp->getResult(0);
488483
rewriter.setInsertionPointAfter(loadOp);
489484
auto cvt = rewriter.create<ConvertLayoutOp>(b.getLoc(), bType, newB);
@@ -549,9 +544,7 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
549544
/*mutableMemory=*/true);
550545
Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout(
551546
instrShape[0], instrShape[1], oldRetType, numWarps);
552-
auto newAccType = RankedTensorType::get(oldRetType.getShape(),
553-
oldRetType.getElementType(),
554-
newDistributedEncoding);
547+
auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding);
555548
Value cvtAcc =
556549
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
557550
auto tokType = rewriter.getType<AsyncTokenType>();
@@ -704,9 +697,7 @@ class ScaledBlockedToMMAv5
704697
/*mutableMemory=*/true);
705698
Attribute newDistributedEncoding =
706699
nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps);
707-
auto newAccType = RankedTensorType::get(oldRetType.getShape(),
708-
oldRetType.getElementType(),
709-
newDistributedEncoding);
700+
auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding);
710701
Value cvtAcc =
711702
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
712703
auto tokType = rewriter.getType<AsyncTokenType>();
@@ -729,10 +720,10 @@ class ScaledBlockedToMMAv5
729720
/*mutableMemory=*/false);
730721
Attribute scaleALayout = getTmemScales(oldScaleAType, numWarps);
731722
Attribute scaleBLayout = getTmemScales(oldScaleBType, numWarps);
732-
RankedTensorType newScaleAType = RankedTensorType::get(
733-
oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleALayout);
734-
RankedTensorType newScaleBType = RankedTensorType::get(
735-
oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleBLayout);
723+
RankedTensorType newScaleAType =
724+
oldScaleAType.cloneWithEncoding(scaleALayout);
725+
RankedTensorType newScaleBType =
726+
oldScaleBType.cloneWithEncoding(scaleBLayout);
736727

737728
auto lhsScale = addSmemStageToScaleLoad(dotOp.getAScale(), rewriter);
738729
auto rhsScale = addSmemStageToScaleLoad(dotOp.getBScale(), rewriter);

0 commit comments

Comments (0)