@@ -172,8 +172,7 @@ struct TritonExpandDimsPattern
172172 // convert operand to slice of return type
173173 Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get (
174174 getContext (), op.getAxis (), retEncoding);
175- RankedTensorType newArgType = RankedTensorType::get (
176- argType.getShape (), argType.getElementType (), newArgEncoding);
175+ RankedTensorType newArgType = argType.cloneWithEncoding (newArgEncoding);
177176 // construct new op
178177 auto newSrc = rewriter.create <triton::gpu::ConvertLayoutOp>(
179178 op.getLoc (), newArgType, adaptor.getSrc ());
@@ -238,8 +237,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
238237 Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get (
239238 getContext (), origShape, retSizePerThread, retOrder, numWarps,
240239 threadsPerWarp, numCTAs);
241- RankedTensorType retType =
242- RankedTensorType::get (origShape, origType.getElementType (), dEncoding);
240+ RankedTensorType retType = origType.cloneWithEncoding (dEncoding);
243241 // a & b must be of smem layout
244242 auto aType = cast<RankedTensorType>(adaptor.getA ().getType ());
245243 auto bType = cast<RankedTensorType>(adaptor.getB ().getType ());
@@ -255,15 +253,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
255253 if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
256254 Attribute encoding = triton::gpu::DotOperandEncodingAttr::get (
257255 getContext (), 0 , dEncoding, aEltType);
258- auto dstType =
259- RankedTensorType::get (aType.getShape (), aEltType, encoding);
256+ auto dstType = aType.cloneWithEncoding (encoding);
260257 a = rewriter.create <triton::gpu::ConvertLayoutOp>(a.getLoc (), dstType, a);
261258 }
262259 if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
263260 Attribute encoding = triton::gpu::DotOperandEncodingAttr::get (
264261 getContext (), 1 , dEncoding, bEltType);
265- auto dstType =
266- RankedTensorType::get (bType.getShape (), bEltType, encoding);
262+ auto dstType = bType.cloneWithEncoding (encoding);
267263 b = rewriter.create <triton::gpu::ConvertLayoutOp>(b.getLoc (), dstType, b);
268264 }
269265 c = rewriter.create <triton::gpu::ConvertLayoutOp>(c.getLoc (), retType, c);
@@ -313,8 +309,7 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
313309 triton::gpu::BlockedEncodingAttr::get (
314310 getContext (), newRetSizePerThread, retThreadsPerWarp,
315311 retWarpsPerCTA, retOrder, retEncoding.getCTALayout ());
316- auto newRetType = RankedTensorType::get (retShape, retType.getElementType (),
317- newRetEncoding);
312+ auto newRetType = retType.cloneWithEncoding (newRetEncoding);
318313 addNamedAttrs (rewriter.replaceOpWithNewOp <triton::CatOp>(
319314 op, newRetType, adaptor.getOperands ()),
320315 adaptor.getAttributes ());
@@ -387,8 +382,7 @@ struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
387382 append (defaultEnc.getCTAsPerCGA (), 1 ),
388383 append (defaultEnc.getCTASplitNum (), 1 ),
389384 prepend (defaultEnc.getCTAOrder (), rank - 1 )));
390- srcTy = RankedTensorType::get (srcTy.getShape (), srcTy.getElementType (),
391- srcEnc);
385+ srcTy = srcTy.cloneWithEncoding (srcEnc);
392386 src = rewriter.create <ConvertLayoutOp>(op.getLoc (), srcTy, src);
393387 }
394388
@@ -427,8 +421,7 @@ struct TritonBroadcastPattern
427421 auto srcEncoding = srcType.getEncoding ();
428422 if (!srcEncoding)
429423 return failure ();
430- Type retType = RankedTensorType::get (
431- op.getType ().getShape (), op.getType ().getElementType (), srcEncoding);
424+ Type retType = op.getType ().cloneWithEncoding (srcEncoding);
432425 // Type retType = this->getTypeConverter()->convertType(op.getType());
433426 addNamedAttrs (rewriter.replaceOpWithNewOp <triton::BroadcastOp>(
434427 op, retType, adaptor.getOperands ()),
0 commit comments