@@ -63,6 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
6363 OpTy packOrUnPackOp) {
6464 static_assert (llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
6565 " applies to only pack or unpack operations" );
66+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
67+ if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
68+ if (!packOrUnPackOp.hasPureTensorSemantics ()) {
69+ return failure ();
70+ }
71+ }
6672 LLVM_DEBUG (
6773 { llvm::dbgs () << " --- Construct PackInfo From an operand ---\n " ; });
6874
@@ -373,6 +379,11 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
373379static FailureOr<GenericOp>
374380bubbleUpPackOpThroughGenericOp (RewriterBase &rewriter, linalg::PackOp packOp,
375381 const ControlPropagationFn &controlFn) {
382+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
383+ if (!packOp.hasPureTensorSemantics ()) {
384+ return failure ();
385+ }
386+
376387 auto genericOp = packOp.getSource ().getDefiningOp <GenericOp>();
377388 if (!genericOp)
378389 return failure ();
@@ -461,6 +472,11 @@ struct BubbleUpPackOpThroughGenericOpPattern
461472
462473 LogicalResult matchAndRewrite (linalg::PackOp packOp,
463474 PatternRewriter &rewriter) const override {
475+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
476+ if (!packOp.hasPureTensorSemantics ()) {
477+ return failure ();
478+ }
479+
464480 auto genericOp =
465481 bubbleUpPackOpThroughGenericOp (rewriter, packOp, controlFn);
466482 if (failed (genericOp))
@@ -483,6 +499,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
483499
484500 LogicalResult matchAndRewrite (linalg::PackOp packOp,
485501 PatternRewriter &rewriter) const override {
502+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
503+ if (!packOp.hasPureTensorSemantics ()) {
504+ return failure ();
505+ }
506+
486507 auto padOp = packOp.getSource ().getDefiningOp <tensor::PadOp>();
487508 if (!padOp)
488509 return failure ();
@@ -651,6 +672,11 @@ static LogicalResult
651672bubbleUpPackOpThroughCollapseShape (tensor::CollapseShapeOp collapseOp,
652673 linalg::PackOp packOp,
653674 PatternRewriter &rewriter) {
675+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
676+ if (!packOp.hasPureTensorSemantics ()) {
677+ return failure ();
678+ }
679+
654680 SmallVector<int64_t > innerTileSizes = packOp.getStaticTiles ();
655681 ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
656682 ArrayRef<int64_t > outerDimsPerm = packOp.getOuterDimsPerm ();
@@ -757,6 +783,11 @@ static LogicalResult
757783bubbleUpPackOpThroughExpandShape (tensor::ExpandShapeOp expandOp,
758784 linalg::PackOp packOp,
759785 PatternRewriter &rewriter) {
786+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
787+ if (!packOp.hasPureTensorSemantics ()) {
788+ return failure ();
789+ }
790+
760791 // Outer dimensions permutation is not supported currently.
761792 // TODO: Handle outer_dims_perm variants.
762793 ArrayRef<int64_t > outerDimsPerm = packOp.getOuterDimsPerm ();
@@ -840,6 +871,11 @@ class BubbleUpPackOpThroughReshapeOp final
840871
841872 LogicalResult matchAndRewrite (linalg::PackOp packOp,
842873 PatternRewriter &rewriter) const override {
874+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
875+ if (!packOp.hasPureTensorSemantics ()) {
876+ return failure ();
877+ }
878+
843879 Operation *srcOp = packOp.getSource ().getDefiningOp ();
844880 // Currently only support when the pack op is the only user.
845881 if (!srcOp || !(srcOp->getNumResults () == 1 ) ||
@@ -893,6 +929,11 @@ class BubbleUpPackOpThroughReshapeOp final
893929static LogicalResult pushDownUnPackOpThroughExpandShape (
894930 linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
895931 PatternRewriter &rewriter, ControlPropagationFn controlFn) {
932+ // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
933+ if (!unPackOp.hasPureTensorSemantics ()) {
934+ return failure ();
935+ }
936+
896937 // User controlled propagation function.
897938 if (!controlFn (&expandOp.getSrcMutable ()))
898939 return failure ();
@@ -970,6 +1011,11 @@ class PushDownUnPackOpThroughReshapeOp final
9701011
9711012 LogicalResult matchAndRewrite (linalg::UnPackOp unPackOp,
9721013 PatternRewriter &rewriter) const override {
1014+ // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
1015+ if (!unPackOp.hasPureTensorSemantics ()) {
1016+ return failure ();
1017+ }
1018+
9731019 Value result = unPackOp.getResult ();
9741020 // Currently only support unpack op with the single user.
9751021 if (!result.hasOneUse ()) {
@@ -1146,11 +1192,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
11461192
11471193 LogicalResult matchAndRewrite (tensor::PadOp padOp,
11481194 PatternRewriter &rewriter) const override {
1195+
11491196 linalg::UnPackOp unpackOp =
11501197 padOp.getSource ().getDefiningOp <linalg::UnPackOp>();
1198+
11511199 if (!unpackOp)
11521200 return failure ();
11531201
1202+ // TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
1203+ if (!unpackOp.hasPureTensorSemantics ())
1204+ return failure ();
1205+
11541206 if (!controlFn (&padOp.getSourceMutable ()))
11551207 return failure ();
11561208
0 commit comments