Skip to content

Commit c6aea91

Browse files
[Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops
-- This commit includes the basic infra/utilities to add matchers for linalg.*conv*/*pool* ops - such that, given a `linalg.generic` op, it identifies which linalg.*conv*/*pool* op it is. -- It adds a few representative linalg.*conv*/*pool* ops to demonstrate the matchers' capability and does so as part of the `linalg-specialize-generic-ops` pass. -- The goal is to iteratively address the aim of [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) for `*conv*/*pooling*` ops. -- This is part 1 of a series of PRs aimed at adding matchers for convolution ops. -- For further details, refer to llvm#163374 (review) Signed-off-by: Abhishek Varma <[email protected]>
1 parent 4f2c867 commit c6aea91

File tree

4 files changed

+767
-0
lines changed

4 files changed

+767
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
110110
std::optional<SmallVector<ReassociationIndices>>
111111
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
112112

//===----------------------------------------------------------------------===//
// Convolution matcher utility
//===----------------------------------------------------------------------===//

/// Returns true if `op` matches the semantics of the named convolution /
/// pooling op `ConvOpTy`. On a successful match, the matched op's dilations
/// and strides are written through `dilations`/`strides` when those
/// out-parameters are non-null.
template <typename ConvOpTy>
bool isaConvolutionOpOfType(LinalgOp op,
                            SmallVector<int64_t> *dilations = nullptr,
                            SmallVector<int64_t> *strides = nullptr);
113122
//===----------------------------------------------------------------------===//
114123
// Fusion / Tiling utilities
115124
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238
}
239239

240+
/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
241+
/// with `dilations` and `strides`.
242+
template <typename ConvOpTy>
243+
static FailureOr<LinalgOp>
244+
specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
245+
ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
246+
SmallVector<Value> inputs = genericOp.getDpsInputs();
247+
ValueRange outputs = genericOp.getDpsInits();
248+
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
249+
SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
250+
? TypeRange(ValueRange(outputs))
251+
: TypeRange{};
252+
LinalgOp namedOp;
253+
if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
254+
std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
255+
std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
256+
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
257+
inputs, outputs);
258+
} else {
259+
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
260+
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
261+
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
262+
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
263+
}
264+
return namedOp;
265+
}
/// TODO(avarma): Convolution ops with a rank-2 iterator types array will be
/// added here incrementally in follow-up PRs. Currently a stub that always
/// fails to match.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}
274+
275+
static FailureOr<LinalgOp>
276+
inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
277+
GenericOp genericOp) {
278+
SmallVector<int64_t> dilations, strides;
279+
if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
280+
genericOp, &dilations, &strides))
281+
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
282+
rewriter, genericOp, dilations, strides);
283+
return failure();
284+
}
/// TODO(avarma): Convolution ops with a rank-5 iterator types array will be
/// added here incrementally in follow-up PRs. Currently a stub that always
/// fails to match.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}
293+
294+
static FailureOr<LinalgOp>
295+
inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
296+
GenericOp genericOp) {
297+
SmallVector<int64_t> dilations, strides;
298+
if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
299+
genericOp, &dilations, &strides))
300+
return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
301+
rewriter, genericOp, dilations, strides);
302+
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
303+
&strides))
304+
return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
305+
dilations, strides);
306+
if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
307+
&strides))
308+
return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
309+
dilations, strides);
310+
if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
311+
&strides))
312+
return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
313+
dilations, strides);
314+
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
315+
genericOp, &dilations, &strides))
316+
return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
317+
rewriter, genericOp, dilations, strides);
318+
if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
319+
genericOp, &dilations, &strides))
320+
return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
321+
rewriter, genericOp, dilations, strides);
322+
return failure();
323+
}
/// TODO(avarma): Convolution ops with a rank-7 iterator types array will be
/// added here incrementally in follow-up PRs. Currently a stub that always
/// fails to match.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}
/// TODO(avarma): Convolution ops with a rank-8 iterator types array will be
/// added here incrementally in follow-up PRs. Currently a stub that always
/// fails to match.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}
340+
341+
static FailureOr<LinalgOp>
342+
inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
343+
GenericOp genericOp) {
344+
SmallVector<int64_t> dilations, strides;
345+
if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
346+
genericOp, &dilations, &strides))
347+
return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
348+
rewriter, genericOp, dilations, strides);
349+
return failure();
350+
}
351+
352+
// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
353+
// improve the search speed, the convolution ops have been segregated based on
354+
// the rank of iterator types array.
355+
static FailureOr<LinalgOp>
356+
inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
357+
SmallVector<utils::IteratorType> iteratorTypes =
358+
genericOp.getIteratorTypesArray();
359+
unsigned totalIterators = iteratorTypes.size();
360+
switch (totalIterators) {
361+
case 2:
362+
return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
363+
case 4:
364+
return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
365+
case 5:
366+
return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
367+
case 6:
368+
return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
369+
case 7:
370+
return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
371+
case 8:
372+
return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
373+
case 9:
374+
return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
375+
}
376+
return failure();
377+
}
378+
240379
} // namespace
241380

242381
//===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
316455
if (isaContractionOpInterface(genericOp)) {
317456
return specializeLinalgContractions(rewriter, genericOp);
318457
}
458+
459+
// Convolution - e.g. *conv/pooling*
460+
if (isaConvolutionOpInterface(genericOp)) {
461+
return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
462+
}
319463
return failure();
320464
}
321465

0 commit comments

Comments
 (0)