@@ -1188,10 +1188,32 @@ mlir::scf::tileAndFuseProducerOfSlice(
11881188 clonedProducerOp->getResult (resultNumber));
11891189 if (failed (tileAndFuseResult))
11901190 return std::nullopt ;
1191- // Note: Do not delete the candidateSliceOp, since its passed in from the
1192- // caller.
1193- rewriter.replaceAllUsesWith (candidateSliceOp,
1194- tileAndFuseResult->tiledValues [0 ]);
1191+
1192+ // Check if the types are the same. If possible insert a cast. Fail otherwise.
1193+ if (tileAndFuseResult->tiledValues [0 ].getType () !=
1194+ candidateSliceOp.getResult ().getType ()) {
1195+ auto tileAndFuseResultType =
1196+ cast<RankedTensorType>(tileAndFuseResult->tiledValues [0 ].getType ());
1197+ auto candidateSliceOpType =
1198+ cast<RankedTensorType>(candidateSliceOp.getResult ().getType ());
1199+ // We can only cast if the tileAndFuseResultType has a static shape and
1200+ // candidateSliceOp has a dynamic shape. Might be expanded in the future.
1201+ if (!tileAndFuseResultType.hasStaticShape () ||
1202+ candidateSliceOpType.hasStaticShape ()) {
1203+ return std::nullopt ;
1204+ }
1205+
1206+ auto castOp = rewriter.create <tensor::CastOp>(
1207+ candidateSliceOp->getLoc (), candidateSliceOpType, tileAndFuseResult->tiledValues [0 ]);
1208+ // Note: Do not delete the candidateSliceOp, since it's passed in from the
1209+ // caller.
1210+ rewriter.replaceAllUsesWith (candidateSliceOp, castOp);
1211+ } else {
1212+ // Note: Do not delete the candidateSliceOp, since it's passed in from the
1213+ // caller.
1214+ rewriter.replaceAllUsesWith (candidateSliceOp,
1215+ tileAndFuseResult->tiledValues [0 ]);
1216+ }
11951217 rewriter.eraseOp (clonedCandidateSliceOp);
11961218 rewriter.eraseOp (clonedProducerOp);
11971219
0 commit comments