@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238}
239239
240+ // / Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
241+ // / with `dilations` and `strides`.
242+ template <typename ConvOpTy>
243+ static FailureOr<LinalgOp>
244+ specializeToConvOp (RewriterBase &rewriter, GenericOp genericOp,
245+ ArrayRef<int64_t > dilations, ArrayRef<int64_t > strides) {
246+ SmallVector<Value> inputs = genericOp.getDpsInputs ();
247+ ValueRange outputs = genericOp.getDpsInits ();
248+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray ();
249+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics ()
250+ ? TypeRange (ValueRange (outputs))
251+ : TypeRange{};
252+ LinalgOp namedOp;
253+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
254+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
255+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
256+ namedOp = rewriter.replaceOpWithNewOp <ConvOpTy>(genericOp, resultTypes,
257+ inputs, outputs);
258+ } else {
259+ Attribute stridesAttr = rewriter.getI64TensorAttr (strides);
260+ Attribute dilationsAttr = rewriter.getI64TensorAttr (dilations);
261+ namedOp = rewriter.replaceOpWithNewOp <ConvOpTy>(
262+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
263+ }
264+ return namedOp;
265+ }
266+
267+ // / TODO(avarma): Convolution ops which rank-2 iteratory types array will be
268+ // / added here incrementally in follow-up PRs.
269+ static FailureOr<LinalgOp>
270+ inferAndSpecializeBasedOnRank2ConvIteratorTypes (RewriterBase &rewriter,
271+ GenericOp genericOp) {
272+ return failure ();
273+ }
274+
275+ static FailureOr<LinalgOp>
276+ inferAndSpecializeBasedOnRank4ConvIteratorTypes (RewriterBase &rewriter,
277+ GenericOp genericOp) {
278+ SmallVector<int64_t > dilations, strides;
279+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
280+ genericOp, &dilations, &strides))
281+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
282+ rewriter, genericOp, dilations, strides);
283+ return failure ();
284+ }
285+
286+ // / TODO(avarma): Convolution ops which rank-5 iteratory types array will be
287+ // / added here incrementally in follow-up PRs.
288+ static FailureOr<LinalgOp>
289+ inferAndSpecializeBasedOnRank5ConvIteratorTypes (RewriterBase &rewriter,
290+ GenericOp genericOp) {
291+ return failure ();
292+ }
293+
294+ static FailureOr<LinalgOp>
295+ inferAndSpecializeBasedOnRank6ConvIteratorTypes (RewriterBase &rewriter,
296+ GenericOp genericOp) {
297+ SmallVector<int64_t > dilations, strides;
298+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
299+ genericOp, &dilations, &strides))
300+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
301+ rewriter, genericOp, dilations, strides);
302+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
303+ &strides))
304+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
305+ dilations, strides);
306+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
307+ &strides))
308+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
309+ dilations, strides);
310+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
311+ &strides))
312+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
313+ dilations, strides);
314+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
315+ genericOp, &dilations, &strides))
316+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
317+ rewriter, genericOp, dilations, strides);
318+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
319+ genericOp, &dilations, &strides))
320+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
321+ rewriter, genericOp, dilations, strides);
322+ return failure ();
323+ }
324+
325+ // / TODO(avarma): Convolution ops which rank-7 iteratory types array will be
326+ // / added here incrementally in follow-up PRs.
327+ static FailureOr<LinalgOp>
328+ inferAndSpecializeBasedOnRank7ConvIteratorTypes (RewriterBase &rewriter,
329+ GenericOp genericOp) {
330+ return failure ();
331+ }
332+
333+ // / TODO(avarma): Convolution ops which rank-8 iteratory types array will be
334+ // / added here incrementally in follow-up PRs.
335+ static FailureOr<LinalgOp>
336+ inferAndSpecializeBasedOnRank8ConvIteratorTypes (RewriterBase &rewriter,
337+ GenericOp genericOp) {
338+ return failure ();
339+ }
340+
341+ static FailureOr<LinalgOp>
342+ inferAndSpecializeBasedOnRank9ConvIteratorTypes (RewriterBase &rewriter,
343+ GenericOp genericOp) {
344+ SmallVector<int64_t > dilations, strides;
345+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
346+ genericOp, &dilations, &strides))
347+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
348+ rewriter, genericOp, dilations, strides);
349+ return failure ();
350+ }
351+
352+ // Converts linalg.generic to named linalg.*conv/pooling* where possible. To
353+ // improve the search speed, the convolution ops have been segregated based on
354+ // the rank of iterator types array.
355+ static FailureOr<LinalgOp>
356+ inferAndSpecializeToConvolutionOp (RewriterBase &rewriter, GenericOp genericOp) {
357+ SmallVector<utils::IteratorType> iteratorTypes =
358+ genericOp.getIteratorTypesArray ();
359+ unsigned totalIterators = iteratorTypes.size ();
360+ switch (totalIterators) {
361+ case 2 :
362+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes (rewriter, genericOp);
363+ case 4 :
364+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes (rewriter, genericOp);
365+ case 5 :
366+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes (rewriter, genericOp);
367+ case 6 :
368+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes (rewriter, genericOp);
369+ case 7 :
370+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes (rewriter, genericOp);
371+ case 8 :
372+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes (rewriter, genericOp);
373+ case 9 :
374+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes (rewriter, genericOp);
375+ }
376+ return failure ();
377+ }
378+
240379} // namespace
241380
242381// ===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
316455 if (isaContractionOpInterface (genericOp)) {
317456 return specializeLinalgContractions (rewriter, genericOp);
318457 }
458+
459+ // Convolution - e.g. *conv/pooling*
460+ if (isaConvolutionOpInterface (genericOp)) {
461+ return inferAndSpecializeToConvolutionOp (rewriter, genericOp);
462+ }
319463 return failure ();
320464}
321465
0 commit comments