@@ -736,6 +736,123 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
736736 matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 4 ));
737737}
738738
739+ bool isaConv3DOp (LinalgOp op) {
740+ if (isa<linalg::Conv1DOp>(op)) return true ;
741+
742+ if (!isaConvolutionOpInterface (op)) return false ;
743+
744+ ArrayAttr indexingMaps = op.getIndexingMaps ();
745+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,3 ,3 })) return false ;
746+
747+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
748+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
749+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
750+ return (matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 ) &&
751+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 ) &&
752+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ));
753+ }
754+
755+ bool isaConv3DNcdhwFcdhwOp (LinalgOp op) {
756+ if (isa<linalg::Conv1DOp>(op)) return true ;
757+
758+ if (!isaConvolutionOpInterface (op)) return false ;
759+
760+ ArrayAttr indexingMaps = op.getIndexingMaps ();
761+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,5 ,5 })) return false ;
762+
763+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
764+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
765+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
766+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
767+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
768+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 1 ) &&
769+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
770+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
771+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 4 , /* fDim=*/ 4 , /* oDim=*/ 4 ) &&
772+ matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 1 ));
773+ }
774+
775+ bool isaConv3DNdhwcDhwcfOp (LinalgOp op) {
776+ if (isa<linalg::Conv1DOp>(op)) return true ;
777+
778+ if (!isaConvolutionOpInterface (op)) return false ;
779+
780+ ArrayAttr indexingMaps = op.getIndexingMaps ();
781+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,5 ,5 })) return false ;
782+
783+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
784+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
785+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
786+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
787+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
788+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
789+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
790+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
791+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 3 ) &&
792+ matchConvDimExprPattern (indexingMaps, fIndex , 4 , oIndex, 4 ));
793+ }
794+
795+ bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op) {
796+ if (isa<linalg::Conv1DOp>(op)) return true ;
797+
798+ if (!isaConvolutionOpInterface (op)) return false ;
799+
800+ ArrayAttr indexingMaps = op.getIndexingMaps ();
801+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,5 ,6 })) return false ;
802+
803+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
804+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
805+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
806+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
807+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
808+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
809+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
810+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
811+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 3 ) &&
812+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , oIndex, 4 ) &&
813+ matchConvDimExprPattern (indexingMaps, fIndex , 4 , oIndex, 5 ));
814+ }
815+
816+ bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op) {
817+ if (isa<linalg::Conv1DOp>(op)) return true ;
818+
819+ if (!isaConvolutionOpInterface (op)) return false ;
820+
821+ ArrayAttr indexingMaps = op.getIndexingMaps ();
822+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,4 ,5 })) return false ;
823+
824+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
825+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
826+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
827+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
828+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
829+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 0 ) &&
830+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
831+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
832+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
833+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 4 , /* fDim=*/ 3 , /* oDim=*/ 4 ));
834+ }
835+
836+ bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op) {
837+ if (isa<linalg::Conv1DOp>(op)) return true ;
838+
839+ if (!isaConvolutionOpInterface (op)) return false ;
840+
841+ ArrayAttr indexingMaps = op.getIndexingMaps ();
842+ if (!verifyConvIndexingMapSizes (indexingMaps, {5 ,4 ,5 })) return false ;
843+
844+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
845+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
846+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
847+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
848+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
849+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
850+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
851+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
852+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 3 ) &&
853+ matchConvDimExprPattern (indexingMaps, iIndex, 4 , oIndex, 4 ));
854+ }
855+
739856Value makeComposedPadHighOp (OpBuilder &b, Location loc, RankedTensorType type,
740857 Value source, Value pad, bool nofold,
741858 ValueRange typeDynDims) {
0 commit comments