@@ -111,123 +111,13 @@ std::optional<SmallVector<ReassociationIndices>>
111111getReassociationMapForFoldingUnitDims (ArrayRef<OpFoldResult> mixedSizes);
112112
113113// ===----------------------------------------------------------------------===//
114- // Convolution matcher utilities
114+ // Convolution matcher utility
115115// ===----------------------------------------------------------------------===//
116116
117- bool isaConv1DOp (LinalgOp op);
118- bool isaConv1DNwcWcfOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
119- SmallVector<int64_t > *strides = nullptr );
120- bool isaConv1DNcwFcwOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
121- SmallVector<int64_t > *strides = nullptr );
122- bool isaDepthwiseConv1DNcwCwOp (LinalgOp op,
123- SmallVector<int64_t > *dilations = nullptr ,
124- SmallVector<int64_t > *strides = nullptr );
125- bool isaDepthwiseConv1DNwcWcOp (LinalgOp op,
126- SmallVector<int64_t > *dilations = nullptr ,
127- SmallVector<int64_t > *strides = nullptr );
128- bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op,
129- SmallVector<int64_t > *dilations = nullptr ,
130- SmallVector<int64_t > *strides = nullptr );
131- bool isaConv2DOp (LinalgOp op);
132- bool isaConv2DNhwcFhwcOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
133- SmallVector<int64_t > *strides = nullptr );
134- bool isaConv2DNhwcHwcfOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
135- SmallVector<int64_t > *strides = nullptr );
136- bool isaConv2DNchwFchwOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
137- SmallVector<int64_t > *strides = nullptr );
138- bool isaConv2DNhwcFhwcQOp (LinalgOp op,
139- SmallVector<int64_t > *dilations = nullptr ,
140- SmallVector<int64_t > *strides = nullptr );
141- bool isaConv2DNchwFchwQOp (LinalgOp op,
142- SmallVector<int64_t > *dilations = nullptr ,
143- SmallVector<int64_t > *strides = nullptr );
144- bool isaConv2DNgchwFgchwOp (LinalgOp op,
145- SmallVector<int64_t > *dilations = nullptr ,
146- SmallVector<int64_t > *strides = nullptr );
147- bool isaConv2DNgchwGfchwOp (LinalgOp op,
148- SmallVector<int64_t > *dilations = nullptr ,
149- SmallVector<int64_t > *strides = nullptr );
150- bool isaConv2DNhwcHwcfQOp (LinalgOp op,
151- SmallVector<int64_t > *dilations = nullptr ,
152- SmallVector<int64_t > *strides = nullptr );
153- bool isaConv2DNhwgcGfhwcQOp (LinalgOp op,
117+ template <typename ConvOpTy>
118+ bool isaConvolutionOpOfType (LinalgOp op,
154119 SmallVector<int64_t > *dilations = nullptr ,
155120 SmallVector<int64_t > *strides = nullptr );
156- bool isaConv2DNgchwGfchwQOp (LinalgOp op,
157- SmallVector<int64_t > *dilations = nullptr ,
158- SmallVector<int64_t > *strides = nullptr );
159- bool isaConv2DNhwgcGfhwcOp (LinalgOp op,
160- SmallVector<int64_t > *dilations = nullptr ,
161- SmallVector<int64_t > *strides = nullptr );
162- bool isaDepthwiseConv2DNchwChwOp (LinalgOp op,
163- SmallVector<int64_t > *dilations = nullptr ,
164- SmallVector<int64_t > *strides = nullptr );
165- bool isaDepthwiseConv2DNhwcHwcOp (LinalgOp op,
166- SmallVector<int64_t > *dilations = nullptr ,
167- SmallVector<int64_t > *strides = nullptr );
168- bool isaDepthwiseConv2DNhwcHwcmOp (LinalgOp op,
169- SmallVector<int64_t > *dilations = nullptr ,
170- SmallVector<int64_t > *strides = nullptr );
171- bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op,
172- SmallVector<int64_t > *dilations = nullptr ,
173- SmallVector<int64_t > *strides = nullptr );
174- bool isaDepthwiseConv2DNhwcHwcmQOp (LinalgOp op,
175- SmallVector<int64_t > *dilations = nullptr ,
176- SmallVector<int64_t > *strides = nullptr );
177- bool isaConv3DOp (LinalgOp op);
178- bool isaConv3DNcdhwFcdhwOp (LinalgOp op,
179- SmallVector<int64_t > *dilations = nullptr ,
180- SmallVector<int64_t > *strides = nullptr );
181- bool isaConv3DNdhwcDhwcfOp (LinalgOp op,
182- SmallVector<int64_t > *dilations = nullptr ,
183- SmallVector<int64_t > *strides = nullptr );
184- bool isaConv3DNdhwcDhwcfQOp (LinalgOp op,
185- SmallVector<int64_t > *dilations = nullptr ,
186- SmallVector<int64_t > *strides = nullptr );
187- bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op,
188- SmallVector<int64_t > *dilations = nullptr ,
189- SmallVector<int64_t > *strides = nullptr );
190- bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op,
191- SmallVector<int64_t > *dilations = nullptr ,
192- SmallVector<int64_t > *strides = nullptr );
193- bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op,
194- SmallVector<int64_t > *dilations = nullptr ,
195- SmallVector<int64_t > *strides = nullptr );
196- bool isaPoolingNchwMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
197- SmallVector<int64_t > *strides = nullptr );
198- bool isaPoolingNchwSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
199- SmallVector<int64_t > *strides = nullptr );
200- bool isaPoolingNhwcMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
201- SmallVector<int64_t > *strides = nullptr );
202- bool isaPoolingNhwcMinOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
203- SmallVector<int64_t > *strides = nullptr );
204- bool isaPoolingNhwcSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
205- SmallVector<int64_t > *strides = nullptr );
206- bool isaPoolingNhwcMaxUnsignedOp (LinalgOp op,
207- SmallVector<int64_t > *dilations = nullptr ,
208- SmallVector<int64_t > *strides = nullptr );
209- bool isaPoolingNhwcMinUnsignedOp (LinalgOp op,
210- SmallVector<int64_t > *dilations = nullptr ,
211- SmallVector<int64_t > *strides = nullptr );
212- bool isaPoolingNcwMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
213- SmallVector<int64_t > *strides = nullptr );
214- bool isaPoolingNcwSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
215- SmallVector<int64_t > *strides = nullptr );
216- bool isaPoolingNwcMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
217- SmallVector<int64_t > *strides = nullptr );
218- bool isaPoolingNwcMinOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
219- SmallVector<int64_t > *strides = nullptr );
220- bool isaPoolingNwcSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
221- SmallVector<int64_t > *strides = nullptr );
222- bool isaPoolingNdhwcMaxOp (LinalgOp op,
223- SmallVector<int64_t > *dilations = nullptr ,
224- SmallVector<int64_t > *strides = nullptr );
225- bool isaPoolingNdhwcMinOp (LinalgOp op,
226- SmallVector<int64_t > *dilations = nullptr ,
227- SmallVector<int64_t > *strides = nullptr );
228- bool isaPoolingNdhwcSumOp (LinalgOp op,
229- SmallVector<int64_t > *dilations = nullptr ,
230- SmallVector<int64_t > *strides = nullptr );
231121
232122// ===----------------------------------------------------------------------===//
233123// Fusion / Tiling utilities
0 commit comments