@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
+  // might need special handling of attached regions.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+
+  return hasTensorCastOperand;
+}
+
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant.
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};

 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has a source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp

   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();

-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
+    // Reject tensor::PackOp - there's a dedicated pattern for it instead.
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();

-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

-    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
+    // Clone op.
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp

 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }

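For readers less familiar with how these dialect canonicalization patterns get exercised, the following is a minimal C++ sketch, not part of this commit: the helper name `runTensorCanonicalization` and the surrounding setup are hypothetical, while `getCanonicalizationPatterns` and `applyPatternsAndFoldGreedily` are standard MLIR entry points. In practice the same patterns run under `mlir-opt --canonicalize`.

```c++
// Hypothetical driver (not part of the commit): collect the tensor dialect's
// canonicalization patterns, which after this change register
// FoldTensorCastPackOp ahead of the generic FoldTensorCastProducerOp, and
// apply them greedily over a module.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runTensorCanonicalization(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Assumes the tensor dialect is already loaded (true whenever the module
  // contains tensor ops).
  ctx->getLoadedDialect<mlir::tensor::TensorDialect>()
      ->getCanonicalizationPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```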