@@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
390390 unsigned inputMapIdx = 0 , filterMapIdx = 1 ,
391391 outputMapIdx = indexingMaps.size () - 1 ;
392392 AffineExpr inpExpr = getAffineMapDim (indexingMaps, inputMapIdx, iDim);
393- auto addExpr = dyn_cast <AffineBinaryOpExpr>(inpExpr);
393+ auto addExpr = dyn_cast_or_null <AffineBinaryOpExpr>(inpExpr);
394394 if (!addExpr || addExpr.getKind () != AffineExprKind::Add)
395395 return false ;
396396
@@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
434434 })));
435435}
436436
437+ // #inputMap = affine_map<(W, w) -> (W + w)>
438+ // #filterMap = affine_map<(W, w) -> (w)>
439+ // #outputMap = affine_map<(W, w) -> (W)>
440+ template <>
441+ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
442+ SmallVector<int64_t > *dilations,
443+ SmallVector<int64_t > *strides) {
444+ if (isa<linalg::Conv1DOp>(op))
445+ return true ;
446+
447+ assert (isaConvolutionOpInterface (op) &&
448+ " expected op to implement ConvolutionOpInterface" );
449+
450+ *dilations = SmallVector<int64_t >(1 , 1 );
451+ *strides = SmallVector<int64_t >(1 , 1 );
452+ MLIRContext *context = op->getContext ();
453+ AffineExpr W = getAffineDimExpr (0 , context);
454+ AffineExpr w = getAffineDimExpr (1 , context);
455+ ArrayAttr indexingMaps = op.getIndexingMaps ();
456+ // First fetch dilations/strides :-
457+ // Match: W * stride + w * dilation
458+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 ,
459+ /* oDim=*/ 0 , (*dilations)[0 ], (*strides)[0 ]))
460+ return false ;
461+ // Match expected indexing maps
462+ if (!convLayoutMatches (
463+ {/* inputMap=*/ {W * (*strides)[0 ] + w * (*dilations)[0 ]},
464+ /* filterMap=*/ {w},
465+ /* outputMap=*/ {W}},
466+ indexingMaps, context))
467+ return false ;
468+ // Match body
469+ Block *body = op.getBlock ();
470+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
471+ Value yieldVal = yieldOp.getOperand (0 );
472+ return bodyMatcherForConvolutionOps (yieldVal, body);
473+ }
474+
475+ // #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
476+ // #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
477+ // #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
478+ template <>
479+ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
480+ LinalgOp op, SmallVector<int64_t > *dilations,
481+ SmallVector<int64_t > *strides) {
482+ if (isa<linalg::Conv1DNwcWcfOp>(op))
483+ return true ;
484+
485+ assert (isaConvolutionOpInterface (op) &&
486+ " expected op to implement ConvolutionOpInterface" );
487+
488+ *dilations = SmallVector<int64_t >(1 , 1 );
489+ *strides = SmallVector<int64_t >(1 , 1 );
490+ MLIRContext *context = op->getContext ();
491+ AffineExpr N = getAffineDimExpr (0 , context);
492+ AffineExpr W = getAffineDimExpr (1 , context);
493+ AffineExpr F = getAffineDimExpr (2 , context);
494+ AffineExpr w = getAffineDimExpr (3 , context);
495+ AffineExpr c = getAffineDimExpr (4 , context);
496+ ArrayAttr indexingMaps = op.getIndexingMaps ();
497+ // First fetch dilations/strides :-
498+ // Match: W * stride + w * dilation
499+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
500+ /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
501+ return false ;
502+ // Match expected indexing maps
503+ if (!convLayoutMatches (
504+ {/* inputMap=*/ {N, W * (*strides)[0 ] + w * (*dilations)[0 ], c},
505+ /* filterMap=*/ {w, c, F},
506+ /* outputMap=*/ {N, W, F}},
507+ indexingMaps, context))
508+ return false ;
509+ // Match body
510+ Block *body = op.getBlock ();
511+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
512+ Value yieldVal = yieldOp.getOperand (0 );
513+ return bodyMatcherForConvolutionOps (yieldVal, body);
514+ }
515+
516+ // #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
517+ // #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
518+ // #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
519+ template <>
520+ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
521+ LinalgOp op, SmallVector<int64_t > *dilations,
522+ SmallVector<int64_t > *strides) {
523+ if (isa<linalg::Conv1DNcwFcwOp>(op))
524+ return true ;
525+
526+ assert (isaConvolutionOpInterface (op) &&
527+ " expected op to implement ConvolutionOpInterface" );
528+
529+ *dilations = SmallVector<int64_t >(1 , 1 );
530+ *strides = SmallVector<int64_t >(1 , 1 );
531+ MLIRContext *context = op->getContext ();
532+ AffineExpr N = getAffineDimExpr (0 , context);
533+ AffineExpr F = getAffineDimExpr (1 , context);
534+ AffineExpr W = getAffineDimExpr (2 , context);
535+ AffineExpr c = getAffineDimExpr (3 , context);
536+ AffineExpr w = getAffineDimExpr (4 , context);
537+ ArrayAttr indexingMaps = op.getIndexingMaps ();
538+ // First fetch dilations/strides :-
539+ // Match: W * stride + w * dilation
540+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 ,
541+ /* oDim=*/ 2 , (*dilations)[0 ], (*strides)[0 ]))
542+ return false ;
543+ // Match expected indexing maps
544+ if (!convLayoutMatches (
545+ {/* inputMap=*/ {N, c, W * (*strides)[0 ] + w * (*dilations)[0 ]},
546+ /* filterMap=*/ {F, c, w},
547+ /* outputMap=*/ {N, F, W}},
548+ indexingMaps, context))
549+ return false ;
550+ // Match body
551+ Block *body = op.getBlock ();
552+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
553+ Value yieldVal = yieldOp.getOperand (0 );
554+ return bodyMatcherForConvolutionOps (yieldVal, body);
555+ }
556+
557+ // #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
558+ // #filterMap = affine_map<(H, W, h, w) -> (h, w)>
559+ // #outputMap = affine_map<(H, W, h, w) -> (H, W)>
560+ template <>
561+ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
562+ SmallVector<int64_t > *dilations,
563+ SmallVector<int64_t > *strides) {
564+ if (isa<linalg::Conv2DOp>(op))
565+ return true ;
566+
567+ assert (isaConvolutionOpInterface (op) &&
568+ " expected op to implement ConvolutionOpInterface" );
569+
570+ *dilations = SmallVector<int64_t >(2 , 1 );
571+ *strides = SmallVector<int64_t >(2 , 1 );
572+ MLIRContext *context = op->getContext ();
573+ AffineExpr H = getAffineDimExpr (0 , context);
574+ AffineExpr W = getAffineDimExpr (1 , context);
575+ AffineExpr h = getAffineDimExpr (2 , context);
576+ AffineExpr w = getAffineDimExpr (3 , context);
577+ ArrayAttr indexingMaps = op.getIndexingMaps ();
578+ // First fetch dilations/strides :-
579+ // Match: H * stride + h * dilation
580+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 ,
581+ /* oDim=*/ 0 , (*dilations)[0 ], (*strides)[0 ]))
582+ return false ;
583+ // Match: W * stride + w * dilation
584+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 ,
585+ /* oDim=*/ 1 , (*dilations)[1 ], (*strides)[1 ]))
586+ return false ;
587+ // Match expected indexing maps
588+ if (!convLayoutMatches (
589+ {/* inputMap=*/ {H * (*strides)[0 ] + h * (*dilations)[0 ],
590+ W * (*strides)[1 ] + w * (*dilations)[1 ]},
591+ /* filterMap=*/ {h, w},
592+ /* outputMap=*/ {H, W}},
593+ indexingMaps, context))
594+ return false ;
595+ // Match body
596+ Block *body = op.getBlock ();
597+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
598+ Value yieldVal = yieldOp.getOperand (0 );
599+ return bodyMatcherForConvolutionOps (yieldVal, body);
600+ }
601+
602+ // #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
603+ // #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
604+ // #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
605+ template <>
606+ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
607+ SmallVector<int64_t > *dilations,
608+ SmallVector<int64_t > *strides) {
609+ if (isa<linalg::Conv3DOp>(op))
610+ return true ;
611+
612+ assert (isaConvolutionOpInterface (op) &&
613+ " expected op to implement ConvolutionOpInterface" );
614+
615+ *dilations = SmallVector<int64_t >(3 , 1 );
616+ *strides = SmallVector<int64_t >(3 , 1 );
617+ MLIRContext *context = op->getContext ();
618+ AffineExpr D = getAffineDimExpr (0 , context);
619+ AffineExpr H = getAffineDimExpr (1 , context);
620+ AffineExpr W = getAffineDimExpr (2 , context);
621+ AffineExpr d = getAffineDimExpr (3 , context);
622+ AffineExpr h = getAffineDimExpr (4 , context);
623+ AffineExpr w = getAffineDimExpr (5 , context);
624+ ArrayAttr indexingMaps = op.getIndexingMaps ();
625+ // First fetch dilations/strides :-
626+ // Match: D * stride + d * dilation
627+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 ,
628+ /* oDim=*/ 0 , (*dilations)[0 ], (*strides)[0 ]))
629+ return false ;
630+ // Match: H * stride + h * dilation
631+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 ,
632+ /* oDim=*/ 1 , (*dilations)[1 ], (*strides)[1 ]))
633+ return false ;
634+ // Match: W * stride + w * dilation
635+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 ,
636+ /* oDim=*/ 2 , (*dilations)[2 ], (*strides)[2 ]))
637+ return false ;
638+ // Match expected indexing maps
639+ if (!convLayoutMatches (
640+ {/* inputMap=*/ {D * (*strides)[0 ] + d * (*dilations)[0 ],
641+ H * (*strides)[1 ] + h * (*dilations)[1 ],
642+ W * (*strides)[2 ] + w * (*dilations)[2 ]},
643+ /* filterMap=*/ {d, h, w},
644+ /* outputMap=*/ {D, H, W}},
645+ indexingMaps, context))
646+ return false ;
647+ // Match body
648+ Block *body = op.getBlock ();
649+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
650+ Value yieldVal = yieldOp.getOperand (0 );
651+ return bodyMatcherForConvolutionOps (yieldVal, body);
652+ }
653+
654+ // #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
655+ // #filterMap = affine_map<(N, W, C, w) -> (C, w)>
656+ // #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
657+ template <>
658+ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
659+ LinalgOp op, SmallVector<int64_t > *dilations,
660+ SmallVector<int64_t > *strides) {
661+ if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
662+ return true ;
663+
664+ assert (isaConvolutionOpInterface (op) &&
665+ " expected op to implement ConvolutionOpInterface" );
666+
667+ *dilations = SmallVector<int64_t >(1 , 1 );
668+ *strides = SmallVector<int64_t >(1 , 1 );
669+ MLIRContext *context = op->getContext ();
670+ AffineExpr N = getAffineDimExpr (0 , context);
671+ AffineExpr W = getAffineDimExpr (1 , context);
672+ AffineExpr C = getAffineDimExpr (2 , context);
673+ AffineExpr w = getAffineDimExpr (3 , context);
674+ ArrayAttr indexingMaps = op.getIndexingMaps ();
675+ // First fetch dilations/strides :-
676+ // Match: W * stride + w * dilation
677+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
678+ /* oDim=*/ 2 , (*dilations)[0 ], (*strides)[0 ]))
679+ return false ;
680+ // Match expected indexing maps
681+ if (!convLayoutMatches (
682+ {/* inputMap=*/ {N, C, W * (*strides)[0 ] + w * (*dilations)[0 ]},
683+ /* filterMap=*/ {C, w},
684+ /* outputMap=*/ {N, C, W}},
685+ indexingMaps, context))
686+ return false ;
687+ // Match body
688+ Block *body = op.getBlock ();
689+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
690+ Value yieldVal = yieldOp.getOperand (0 );
691+ return bodyMatcherForConvolutionOps (yieldVal, body);
692+ }
693+
437694// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
438695// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
439696// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -474,6 +731,47 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
474731 return bodyMatcherForConvolutionOps (yieldVal, body);
475732}
476733
734+ // #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
735+ // #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
736+ // #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
737+ template <>
738+ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
739+ LinalgOp op, SmallVector<int64_t > *dilations,
740+ SmallVector<int64_t > *strides) {
741+ if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
742+ return true ;
743+
744+ assert (isaConvolutionOpInterface (op) &&
745+ " expected op to implement ConvolutionOpInterface" );
746+
747+ *dilations = SmallVector<int64_t >(1 , 1 );
748+ *strides = SmallVector<int64_t >(1 , 1 );
749+ MLIRContext *context = op->getContext ();
750+ AffineExpr N = getAffineDimExpr (0 , context);
751+ AffineExpr W = getAffineDimExpr (1 , context);
752+ AffineExpr C = getAffineDimExpr (2 , context);
753+ AffineExpr CM = getAffineDimExpr (3 , context);
754+ AffineExpr w = getAffineDimExpr (4 , context);
755+ ArrayAttr indexingMaps = op.getIndexingMaps ();
756+ // First fetch dilations/strides :-
757+ // Match: W * stride + w * dilation
758+ if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
759+ /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
760+ return false ;
761+ // Match expected indexing maps
762+ if (!convLayoutMatches (
763+ {/* inputMap=*/ {N, W * (*strides)[0 ] + w * (*dilations)[0 ], C},
764+ /* filterMap=*/ {w, C, CM},
765+ /* outputMap=*/ {N, W, C, CM}},
766+ indexingMaps, context))
767+ return false ;
768+ // Match body
769+ Block *body = op.getBlock ();
770+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
771+ Value yieldVal = yieldOp.getOperand (0 );
772+ return bodyMatcherForConvolutionOps (yieldVal, body);
773+ }
774+
477775// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
478776// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
479777// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
0 commit comments