@@ -614,6 +614,48 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
614614 matchConvDimExprPattern (indexingMaps, fIndex , 1 , oIndex, 2 ));
615615}
616616
617+ bool isaConv2DNhwcHwcfQOp (LinalgOp op) {
618+ if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true ;
619+
620+ if (!isaConvolutionOpInterface (op)) return false ;
621+
622+ ArrayAttr indexingMaps = op.getIndexingMaps ();
623+ if (!verifyConvIndexingMapSizes (indexingMaps, {4 ,4 ,0 ,0 ,4 })) return false ;
624+
625+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 4 ;
626+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
627+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
628+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
629+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
630+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
631+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
632+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
633+ matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
634+ matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 3 ));
635+ }
636+
637+ bool isaConv2DNhwgcGfhwcQOp (LinalgOp op) {
638+ if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true ;
639+
640+ if (!isaConvolutionOpInterface (op)) return false ;
641+
642+ ArrayAttr indexingMaps = op.getIndexingMaps ();
643+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,5 ,0 ,0 ,5 })) return false ;
644+
645+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 4 ;
646+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
647+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
648+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
649+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
650+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
651+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 2 , /* oDim=*/ 1 ) &&
652+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 3 , /* oDim=*/ 2 ) &&
653+ matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 0 ) &&
654+ matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ) &&
655+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 4 ) &&
656+ matchConvDimExprPattern (indexingMaps, fIndex , 1 , oIndex, 4 ));
657+ }
658+
617659bool isaConv2DNgchwGfchwQOp (LinalgOp op) {
618660 if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true ;
619661
@@ -736,6 +778,26 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
736778 matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 4 ));
737779}
738780
781+ bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op) {
782+ if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true ;
783+
784+ if (!isaConvolutionOpInterface (op)) return false ;
785+
786+ ArrayAttr indexingMaps = op.getIndexingMaps ();
787+ if (!verifyConvIndexingMapSizes (indexingMaps, {4 ,3 ,0 ,0 ,4 })) return false ;
788+
789+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 4 ;
790+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
791+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
792+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
793+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
794+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
795+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
796+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
797+ matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
798+ matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ));
799+ }
800+
739801bool isaConv3DOp (LinalgOp op) {
740802 if (isa<linalg::Conv3DOp>(op)) return true ;
741803
@@ -792,6 +854,27 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
792854 matchConvDimExprPattern (indexingMaps, fIndex , 4 , oIndex, 4 ));
793855}
794856
857+ bool isaConv3DNdhwcDhwcfQOp (LinalgOp op) {
858+ if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true ;
859+
860+ if (!isaConvolutionOpInterface (op)) return false ;
861+
862+ ArrayAttr indexingMaps = op.getIndexingMaps ();
863+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,5 ,0 ,0 ,5 })) return false ;
864+
865+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 4 ;
866+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
867+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
868+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
869+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
870+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
871+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
872+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
873+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
874+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 3 ) &&
875+ matchConvDimExprPattern (indexingMaps, fIndex , 4 , oIndex, 4 ));
876+ }
877+
795878bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op) {
796879 if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true ;
797880
0 commit comments