@@ -442,6 +442,20 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
442442// Matchers for specific convolution operation.
443443// ---------------------------------------------
444444
445+ // / Returns true if the given indexing maps matches with the expected indexing
446+ // / maps.
447+ static bool convLayoutMatches (ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
448+ ArrayAttr indexingMaps, MLIRContext *context) {
449+ SmallVector<AffineMap, 4 > expectedIndexingMaps =
450+ AffineMap::inferFromExprList (mapListExpected, context);
451+ return indexingMaps ==
452+ ArrayAttr::get (
453+ context, llvm::to_vector<4 >(llvm::map_range (
454+ expectedIndexingMaps, [&](AffineMap m) -> Attribute {
455+ return AffineMapAttr::get (m);
456+ })));
457+ }
458+
445459// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
446460// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
447461// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -459,25 +473,25 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
459473 if (!verifyConvIndexingMapSizes (indexingMaps, {3 , 2 , 3 }))
460474 return false ;
461475
462- unsigned inputMapIdx = 0 , filterMapIdx = 1 , outputMapIdx = 2 ;
463-
464476 *dilations = SmallVector<int64_t >(1 , 1 );
465477 *strides = SmallVector<int64_t >(1 , 1 );
466- // Match: N
467- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
468- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
469- return false ;
470- // Match: C
471- if (getAffineMapDim (indexingMaps, inputMapIdx, 2 ) !=
472- getAffineMapDim (indexingMaps, filterMapIdx, 1 ))
473- return false ;
474- if (getAffineMapDim (indexingMaps, inputMapIdx, 2 ) !=
475- getAffineMapDim (indexingMaps, outputMapIdx, 2 ))
476- return false ;
477- // Match: W + w
478+ MLIRContext *context = op->getContext ();
479+ AffineExpr N = getAffineDimExpr (0 , context);
480+ AffineExpr W = getAffineDimExpr (1 , context);
481+ AffineExpr C = getAffineDimExpr (2 , context);
482+ AffineExpr w = getAffineDimExpr (3 , context);
483+ // First fetch dilations/strides :-
484+ // Match: W * stride + w * dilation
478485 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
479486 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
480487 return false ;
488+ // Match expected indexing maps
489+ if (!convLayoutMatches (
490+ {/* inputMap=*/ {N, W * (*strides)[0 ] + w * (*dilations)[0 ], C},
491+ /* filterMap=*/ {w, C},
492+ /* outputMap=*/ {N, W, C}},
493+ indexingMaps, context))
494+ return false ;
481495 // Match body
482496 Block *body = op.getBlock ();
483497 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
@@ -504,29 +518,32 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
504518 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 3 , 4 }))
505519 return false ;
506520
507- unsigned inputMapIdx = 0 , filterMapIdx = 1 , outputMapIdx = 2 ;
508-
509521 *dilations = SmallVector<int64_t >(2 , 1 );
510522 *strides = SmallVector<int64_t >(2 , 1 );
511- // Match: N
512- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
513- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
514- return false ;
515- // Match: C
516- if (getAffineMapDim (indexingMaps, inputMapIdx, 1 ) !=
517- getAffineMapDim (indexingMaps, filterMapIdx, 0 ))
518- return false ;
519- if (getAffineMapDim (indexingMaps, inputMapIdx, 1 ) !=
520- getAffineMapDim (indexingMaps, outputMapIdx, 1 ))
521- return false ;
522- // Match: H + h
523+ MLIRContext *context = op->getContext ();
524+ AffineExpr N = getAffineDimExpr (0 , context);
525+ AffineExpr H = getAffineDimExpr (1 , context);
526+ AffineExpr W = getAffineDimExpr (2 , context);
527+ AffineExpr C = getAffineDimExpr (3 , context);
528+ AffineExpr h = getAffineDimExpr (4 , context);
529+ AffineExpr w = getAffineDimExpr (5 , context);
530+ // First fetch dilations/strides :-
531+ // Match: H * stride + h * dilation
523532 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
524533 /* oDim=*/ 2 , (*dilations)[0 ], (*strides)[0 ]))
525534 return false ;
526- // Match: W + w
535+ // Match: W * stride + w * dilation
527536 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 ,
528537 /* oDim=*/ 3 , (*dilations)[1 ], (*strides)[1 ]))
529538 return false ;
539+ // Match expected indexing maps
540+ if (!convLayoutMatches (
541+ {/* inputMap=*/ {N, C, H * (*strides)[0 ] + h * (*dilations)[0 ],
542+ W * (*strides)[1 ] + w * (*dilations)[1 ]},
543+ /* filterMap=*/ {C, h, w},
544+ /* outputMap=*/ {N, C, H, W}},
545+ indexingMaps, context))
546+ return false ;
530547 // Match body
531548 Block *body = op.getBlock ();
532549 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
@@ -556,36 +573,39 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
556573 if (!verifyConvIndexingMapSizes (indexingMaps, {5 , 5 , 6 }))
557574 return false ;
558575
559- unsigned inputMapIdx = 0 , filterMapIdx = 1 , outputMapIdx = 2 ;
560-
561576 *dilations = SmallVector<int64_t >(3 , 1 );
562577 *strides = SmallVector<int64_t >(3 , 1 );
563- // Match: N
564- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
565- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
566- return false ;
567- // Match: D + d
578+ MLIRContext *context = op->getContext ();
579+ AffineExpr N = getAffineDimExpr (0 , context);
580+ AffineExpr D = getAffineDimExpr (1 , context);
581+ AffineExpr H = getAffineDimExpr (2 , context);
582+ AffineExpr W = getAffineDimExpr (3 , context);
583+ AffineExpr CM = getAffineDimExpr (4 , context);
584+ AffineExpr d = getAffineDimExpr (5 , context);
585+ AffineExpr h = getAffineDimExpr (6 , context);
586+ AffineExpr w = getAffineDimExpr (7 , context);
587+ AffineExpr C = getAffineDimExpr (8 , context);
588+ // First fetch dilations/strides :-
589+ // Match: D * stride + d * dilation
568590 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
569591 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
570592 return false ;
571- // Match: H + h
593+ // Match: H * stride + h * dilation
572594 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
573595 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
574596 return false ;
575- // Match: W + w
597+ // Match: W * stride + w * dilation
576598 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 ,
577599 /* oDim=*/ 3 , (*dilations)[2 ], (*strides)[2 ]))
578600 return false ;
579- // Match: C
580- if (getAffineMapDim (indexingMaps, inputMapIdx, 4 ) !=
581- getAffineMapDim (indexingMaps, filterMapIdx, 3 ))
582- return false ;
583- if (getAffineMapDim (indexingMaps, inputMapIdx, 4 ) !=
584- getAffineMapDim (indexingMaps, outputMapIdx, 4 ))
585- return false ;
586- // Match: CM
587- if (getAffineMapDim (indexingMaps, filterMapIdx, 4 ) !=
588- getAffineMapDim (indexingMaps, outputMapIdx, 5 ))
601+ // Match expected indexing maps
602+ if (!convLayoutMatches (
603+ {/* inputMap=*/ {N, D * (*strides)[0 ] + d * (*dilations)[0 ],
604+ H * (*strides)[1 ] + h * (*dilations)[1 ],
605+ W * (*strides)[2 ] + w * (*dilations)[2 ], C},
606+ /* filterMap=*/ {d, h, w, C, CM},
607+ /* outputMap=*/ {N, D, H, W, C, CM}},
608+ indexingMaps, context))
589609 return false ;
590610 // Match body
591611 Block *body = op.getBlock ();
@@ -613,25 +633,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
613633 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 2 , 4 }))
614634 return false ;
615635
616- unsigned inputMapIdx = 0 , outputMapIdx = 2 ;
617-
618636 *dilations = SmallVector<int64_t >(2 , 1 );
619637 *strides = SmallVector<int64_t >(2 , 1 );
620- // Match: N
621- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
622- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
623- return false ;
624- // Match: H + h
638+ MLIRContext *context = op->getContext ();
639+ AffineExpr N = getAffineDimExpr (0 , context);
640+ AffineExpr H = getAffineDimExpr (1 , context);
641+ AffineExpr W = getAffineDimExpr (2 , context);
642+ AffineExpr C = getAffineDimExpr (3 , context);
643+ AffineExpr h = getAffineDimExpr (4 , context);
644+ AffineExpr w = getAffineDimExpr (5 , context);
645+ // First fetch dilations/strides :-
646+ // Match: H * stride + h * dilation
625647 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
626648 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
627649 return false ;
628- // Match: W + w
650+ // Match: W * stride + w * dilation
629651 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
630652 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
631653 return false ;
632- // Match: C
633- if (getAffineMapDim (indexingMaps, inputMapIdx, 3 ) !=
634- getAffineMapDim (indexingMaps, outputMapIdx, 3 ))
654+ // Match expected indexing maps
655+ if (!convLayoutMatches (
656+ {/* inputMap=*/ {N, H * (*strides)[0 ] + h * (*dilations)[0 ],
657+ W * (*strides)[1 ] + w * (*dilations)[1 ], C},
658+ /* filterMap=*/ {h, w},
659+ /* outputMap=*/ {N, H, W, C}},
660+ indexingMaps, context))
635661 return false ;
636662 // Match body
637663 Block *body = op.getBlock ();
@@ -659,25 +685,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
659685 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 2 , 4 }))
660686 return false ;
661687
662- unsigned inputMapIdx = 0 , outputMapIdx = 2 ;
663-
664688 *dilations = SmallVector<int64_t >(2 , 1 );
665689 *strides = SmallVector<int64_t >(2 , 1 );
666- // Match: N
667- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
668- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
669- return false ;
670- // Match: H + h
690+ MLIRContext *context = op->getContext ();
691+ AffineExpr N = getAffineDimExpr (0 , context);
692+ AffineExpr H = getAffineDimExpr (1 , context);
693+ AffineExpr W = getAffineDimExpr (2 , context);
694+ AffineExpr C = getAffineDimExpr (3 , context);
695+ AffineExpr h = getAffineDimExpr (4 , context);
696+ AffineExpr w = getAffineDimExpr (5 , context);
697+ // First fetch dilations/strides :-
698+ // Match: H * stride + h * dilation
671699 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
672700 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
673701 return false ;
674- // Match: W + w
702+ // Match: W * stride + w * dilation
675703 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
676704 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
677705 return false ;
678- // Match: C
679- if (getAffineMapDim (indexingMaps, inputMapIdx, 3 ) !=
680- getAffineMapDim (indexingMaps, outputMapIdx, 3 ))
706+ // Match expected indexing maps
707+ if (!convLayoutMatches (
708+ {/* inputMap=*/ {N, H * (*strides)[0 ] + h * (*dilations)[0 ],
709+ W * (*strides)[1 ] + w * (*dilations)[1 ], C},
710+ /* filterMap=*/ {h, w},
711+ /* outputMap=*/ {N, H, W, C}},
712+ indexingMaps, context))
681713 return false ;
682714 // Match body
683715 Block *body = op.getBlock ();
@@ -705,25 +737,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
705737 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 2 , 4 }))
706738 return false ;
707739
708- unsigned inputMapIdx = 0 , outputMapIdx = 2 ;
709-
710740 *dilations = SmallVector<int64_t >(2 , 1 );
711741 *strides = SmallVector<int64_t >(2 , 1 );
712- // Match: N
713- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
714- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
715- return false ;
716- // Match: H + h
742+ MLIRContext *context = op->getContext ();
743+ AffineExpr N = getAffineDimExpr (0 , context);
744+ AffineExpr H = getAffineDimExpr (1 , context);
745+ AffineExpr W = getAffineDimExpr (2 , context);
746+ AffineExpr C = getAffineDimExpr (3 , context);
747+ AffineExpr h = getAffineDimExpr (4 , context);
748+ AffineExpr w = getAffineDimExpr (5 , context);
749+ // First fetch dilations/strides :-
750+ // Match: H * stride + h * dilation
717751 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
718752 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
719753 return false ;
720- // Match: W + w
754+ // Match: W * stride + w * dilation
721755 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
722756 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
723757 return false ;
724- // Match: C
725- if (getAffineMapDim (indexingMaps, inputMapIdx, 3 ) !=
726- getAffineMapDim (indexingMaps, outputMapIdx, 3 ))
758+ // Match expected indexing maps
759+ if (!convLayoutMatches (
760+ {/* inputMap=*/ {N, H * (*strides)[0 ] + h * (*dilations)[0 ],
761+ W * (*strides)[1 ] + w * (*dilations)[1 ], C},
762+ /* filterMap=*/ {h, w},
763+ /* outputMap=*/ {N, H, W, C}},
764+ indexingMaps, context))
727765 return false ;
728766 // Match body
729767 Block *body = op.getBlock ();
@@ -751,25 +789,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
751789 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 2 , 4 }))
752790 return false ;
753791
754- unsigned inputMapIdx = 0 , outputMapIdx = 2 ;
755-
756792 *dilations = SmallVector<int64_t >(2 , 1 );
757793 *strides = SmallVector<int64_t >(2 , 1 );
758- // Match: N
759- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
760- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
761- return false ;
762- // Match: H + h
794+ MLIRContext *context = op->getContext ();
795+ AffineExpr N = getAffineDimExpr (0 , context);
796+ AffineExpr H = getAffineDimExpr (1 , context);
797+ AffineExpr W = getAffineDimExpr (2 , context);
798+ AffineExpr C = getAffineDimExpr (3 , context);
799+ AffineExpr h = getAffineDimExpr (4 , context);
800+ AffineExpr w = getAffineDimExpr (5 , context);
801+ // First fetch dilations/strides :-
802+ // Match: H * stride + h * dilation
763803 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
764804 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
765805 return false ;
766- // Match: W + w
806+ // Match: W * stride + w * dilation
767807 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
768808 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
769809 return false ;
770- // Match: C
771- if (getAffineMapDim (indexingMaps, inputMapIdx, 3 ) !=
772- getAffineMapDim (indexingMaps, outputMapIdx, 3 ))
810+ // Match expected indexing maps
811+ if (!convLayoutMatches (
812+ {/* inputMap=*/ {N, H * (*strides)[0 ] + h * (*dilations)[0 ],
813+ W * (*strides)[1 ] + w * (*dilations)[1 ], C},
814+ /* filterMap=*/ {h, w},
815+ /* outputMap=*/ {N, H, W, C}},
816+ indexingMaps, context))
773817 return false ;
774818 // Match body
775819 Block *body = op.getBlock ();
@@ -797,25 +841,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
797841 if (!verifyConvIndexingMapSizes (indexingMaps, {4 , 2 , 4 }))
798842 return false ;
799843
800- unsigned inputMapIdx = 0 , outputMapIdx = 2 ;
801-
802844 *dilations = SmallVector<int64_t >(2 , 1 );
803845 *strides = SmallVector<int64_t >(2 , 1 );
804- // Match: N
805- if (getAffineMapDim (indexingMaps, inputMapIdx, 0 ) !=
806- getAffineMapDim (indexingMaps, outputMapIdx, 0 ))
807- return false ;
808- // Match: H + h
846+ MLIRContext *context = op->getContext ();
847+ AffineExpr N = getAffineDimExpr (0 , context);
848+ AffineExpr H = getAffineDimExpr (1 , context);
849+ AffineExpr W = getAffineDimExpr (2 , context);
850+ AffineExpr C = getAffineDimExpr (3 , context);
851+ AffineExpr h = getAffineDimExpr (4 , context);
852+ AffineExpr w = getAffineDimExpr (5 , context);
853+ // First fetch dilations/strides :-
854+ // Match: H * stride + h * dilation
809855 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 ,
810856 /* oDim=*/ 1 , (*dilations)[0 ], (*strides)[0 ]))
811857 return false ;
812- // Match: W + w
858+ // Match: W * stride + w * dilation
813859 if (!matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 ,
814860 /* oDim=*/ 2 , (*dilations)[1 ], (*strides)[1 ]))
815861 return false ;
816- // Match: C
817- if (getAffineMapDim (indexingMaps, inputMapIdx, 3 ) !=
818- getAffineMapDim (indexingMaps, outputMapIdx, 3 ))
862+ // Match expected indexing maps
863+ if (!convLayoutMatches (
864+ {/* inputMap=*/ {N, H * (*strides)[0 ] + h * (*dilations)[0 ],
865+ W * (*strides)[1 ] + w * (*dilations)[1 ], C},
866+ /* filterMap=*/ {h, w},
867+ /* outputMap=*/ {N, H, W, C}},
868+ indexingMaps, context))
819869 return false ;
820870 // Match body
821871 Block *body = op.getBlock ();
0 commit comments