@@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
715715// ===----------------------------------------------------------------------===//
716716
717717namespace {
718- // / If the source/target of a CopyOp is a CastOp that does not modify the shape
719- // / and element type, the cast can be skipped. Such CastOps only cast the layout
720- // / of the type.
721- struct FoldCopyOfCast : public OpRewritePattern <CopyOp> {
722- using OpRewritePattern<CopyOp>::OpRewritePattern;
723-
724- LogicalResult matchAndRewrite (CopyOp copyOp,
725- PatternRewriter &rewriter) const override {
726- bool modified = false ;
727-
728- // Check source.
729- if (auto castOp = copyOp.getSource ().getDefiningOp <CastOp>()) {
730- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
731- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
732-
733- if (fromType && toType) {
734- if (fromType.getShape () == toType.getShape () &&
735- fromType.getElementType () == toType.getElementType ()) {
736- rewriter.modifyOpInPlace (copyOp, [&] {
737- copyOp.getSourceMutable ().assign (castOp.getSource ());
738- });
739- modified = true ;
740- }
741- }
742- }
743-
744- // Check target.
745- if (auto castOp = copyOp.getTarget ().getDefiningOp <CastOp>()) {
746- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
747- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
748-
749- if (fromType && toType) {
750- if (fromType.getShape () == toType.getShape () &&
751- fromType.getElementType () == toType.getElementType ()) {
752- rewriter.modifyOpInPlace (copyOp, [&] {
753- copyOp.getTargetMutable ().assign (castOp.getSource ());
754- });
755- modified = true ;
756- }
757- }
758- }
759-
760- return success (modified);
761- }
762- };
763718
764719// / Fold memref.copy(%x, %x).
765720struct FoldSelfCopy : public OpRewritePattern <CopyOp> {
@@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
797752
798753void CopyOp::getCanonicalizationPatterns (RewritePatternSet &results,
799754 MLIRContext *context) {
800- results.add <FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
755+ results.add <FoldEmptyCopy, FoldSelfCopy>(context);
801756}
802757
803- LogicalResult CopyOp::fold (FoldAdaptor adaptor,
804- SmallVectorImpl<OpFoldResult> &results) {
805- // / copy(memrefcast) -> copy
806- bool folded = false ;
807- Operation *op = *this ;
758+ // / If the source/target of a CopyOp is a CastOp that does not modify the shape
759+ // / and element type, the cast can be skipped. Such CastOps only cast the layout
760+ // / of the type.
761+ static LogicalResult FoldCopyOfCast (CopyOp op) {
808762 for (OpOperand &operand : op->getOpOperands ()) {
809763 auto castOp = operand.get ().getDefiningOp <memref::CastOp>();
810764 if (castOp && memref::CastOp::canFoldIntoConsumerOp (castOp)) {
811765 operand.set (castOp.getOperand ());
812- folded = true ;
766+ return success () ;
813767 }
814768 }
815- return success (folded);
769+ return failure ();
770+ }
771+
772+ LogicalResult CopyOp::fold (FoldAdaptor adaptor,
773+ SmallVectorImpl<OpFoldResult> &results) {
774+
775+ // / copy(memrefcast) -> copy
776+ return FoldCopyOfCast (*this );
816777}
817778
818779// ===----------------------------------------------------------------------===//
0 commit comments