@@ -115,50 +115,50 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115// ===----------------------------------------------------------------------===//
116116
117117bool isaConv1DOp (LinalgOp op);
118- bool isaConv1DNwcWcfOp (LinalgOp op);
119- bool isaConv1DNcwFcwOp (LinalgOp op);
120- bool isaDepthwiseConv1DNcwCwOp (LinalgOp op);
121- bool isaDepthwiseConv1DNwcWcOp (LinalgOp op);
122- bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op);
118+ bool isaConv1DNwcWcfOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
119+ bool isaConv1DNcwFcwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
120+ bool isaDepthwiseConv1DNcwCwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
121+ bool isaDepthwiseConv1DNwcWcOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
122+ bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
123123bool isaConv2DOp (LinalgOp op);
124- bool isaConv2DNhwcFhwcOp (LinalgOp op);
125- bool isaConv2DNhwcHwcfOp (LinalgOp op);
126- bool isaConv2DNchwFchwOp (LinalgOp op);
127- bool isaConv2DNhwcFhwcQOp (LinalgOp op);
128- bool isaConv2DNchwFchwQOp (LinalgOp op);
129- bool isaConv2DNgchwFgchwOp (LinalgOp op);
130- bool isaConv2DNgchwGfchwOp (LinalgOp op);
131- bool isaConv2DNhwcHwcfQOp (LinalgOp op);
132- bool isaConv2DNhwgcGfhwcQOp (LinalgOp op);
133- bool isaConv2DNgchwGfchwQOp (LinalgOp op);
134- bool isaConv2DNhwgcGfhwcOp (LinalgOp op);
135- bool isaDepthwiseConv2DNchwChwOp (LinalgOp op);
136- bool isaDepthwiseConv2DNhwcHwcOp (LinalgOp op);
137- bool isaDepthwiseConv2DNhwcHwcmOp (LinalgOp op);
138- bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op);
139- bool isaDepthwiseConv2DNhwcHwcmQOp (LinalgOp op);
124+ bool isaConv2DNhwcFhwcOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
125+ bool isaConv2DNhwcHwcfOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
126+ bool isaConv2DNchwFchwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
127+ bool isaConv2DNhwcFhwcQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
128+ bool isaConv2DNchwFchwQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
129+ bool isaConv2DNgchwFgchwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
130+ bool isaConv2DNgchwGfchwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
131+ bool isaConv2DNhwcHwcfQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
132+ bool isaConv2DNhwgcGfhwcQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
133+ bool isaConv2DNgchwGfchwQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
134+ bool isaConv2DNhwgcGfhwcOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
135+ bool isaDepthwiseConv2DNchwChwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
136+ bool isaDepthwiseConv2DNhwcHwcOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
137+ bool isaDepthwiseConv2DNhwcHwcmOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
138+ bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
139+ bool isaDepthwiseConv2DNhwcHwcmQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
140140bool isaConv3DOp (LinalgOp op);
141- bool isaConv3DNcdhwFcdhwOp (LinalgOp op);
142- bool isaConv3DNdhwcDhwcfOp (LinalgOp op);
143- bool isaConv3DNdhwcDhwcfQOp (LinalgOp op);
144- bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op);
145- bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op);
146- bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op);
147- bool isaPoolingNchwMaxOp (LinalgOp op);
148- bool isaPoolingNchwSumOp (LinalgOp op);
149- bool isaPoolingNhwcMaxOp (LinalgOp op);
150- bool isaPoolingNhwcMinOp (LinalgOp op);
151- bool isaPoolingNhwcSumOp (LinalgOp op);
152- bool isaPoolingNhwcMaxUnsignedOp (LinalgOp op);
153- bool isaPoolingNhwcMinUnsignedOp (LinalgOp op);
154- bool isaPoolingNcwMaxOp (LinalgOp op);
155- bool isaPoolingNcwSumOp (LinalgOp op);
156- bool isaPoolingNwcMaxOp (LinalgOp op);
157- bool isaPoolingNwcMinOp (LinalgOp op);
158- bool isaPoolingNwcSumOp (LinalgOp op);
159- bool isaPoolingNdhwcMaxOp (LinalgOp op);
160- bool isaPoolingNdhwcMinOp (LinalgOp op);
161- bool isaPoolingNdhwcSumOp (LinalgOp op);
141+ bool isaConv3DNcdhwFcdhwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
142+ bool isaConv3DNdhwcDhwcfOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
143+ bool isaConv3DNdhwcDhwcfQOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
144+ bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
145+ bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
146+ bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
147+ bool isaPoolingNchwMaxOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
148+ bool isaPoolingNchwSumOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
149+ bool isaPoolingNhwcMaxOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
150+ bool isaPoolingNhwcMinOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
151+ bool isaPoolingNhwcSumOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
152+ bool isaPoolingNhwcMaxUnsignedOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
153+ bool isaPoolingNhwcMinUnsignedOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
154+ bool isaPoolingNcwMaxOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
155+ bool isaPoolingNcwSumOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
156+ bool isaPoolingNwcMaxOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
157+ bool isaPoolingNwcMinOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
158+ bool isaPoolingNwcSumOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
159+ bool isaPoolingNdhwcMaxOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
160+ bool isaPoolingNdhwcMinOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
161+ bool isaPoolingNdhwcSumOp (LinalgOp op, SmallVector< int64_t >* dilations = nullptr , SmallVector< int64_t >* strides = nullptr );
162162
163163// ===----------------------------------------------------------------------===//
164164// Fusion / Tiling utilities
0 commit comments