
Commit b5c459d

[Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (#163724)
-- This commit includes the basic infra/utilities to add matchers for linalg.*conv*/*pool* ops, such that given a `linalg.generic` op it identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demonstrate the matchers' capability, and does so as part of the `linalg-specialize-generic-ops` pass.
-- This works iteratively towards the aim of [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) for `*conv*/*pooling*` ops.
-- This is part 1 of a series of PRs aimed at adding matchers for convolution ops.
-- For further details, refer to #163374 (review).

Signed-off-by: Abhishek Varma <[email protected]>
1 parent 80ae168 commit b5c459d

4 files changed, +759 -0 lines changed


mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 11 additions & 0 deletions
@@ -102,6 +102,17 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+/// Given a linalg `op` this function returns true if it is a convolution op of
+/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
+/// from the indexing maps.
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 50 additions & 0 deletions
@@ -237,6 +237,51 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
+/// with `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+  Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+      genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  return namedOp;
+}
+
+/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
+  if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))       \
+    return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
+                                        strides);
+  // -----------------------------
+  // Depthwise Convolution ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
+  // -----------------------------
+  // Pooling ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+#undef CONV_OP_SPECIALIZER
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +361,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return specializeLinalgConvolutions(rewriter, genericOp);
+  }
   return failure();
 }
 
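
Downstream users do not have to go through the `linalg-specialize-generic-ops` pass; the public entry point touched in the second hunk can also be driven from a rewrite pattern. A minimal sketch, in which only `linalg::specializeGenericOp` is from the source and the pattern name is illustrative (assume `using namespace mlir;` and that `mlir/Dialect/Linalg/Transforms/Transforms.h` is included):

    // Tries to raise a linalg.generic back to a named op; with this commit the
    // dispatch also covers the depthwise-conv and pooling ops listed above.
    struct RaiseGenericToNamedOp : public OpRewritePattern<linalg::GenericOp> {
      using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                    PatternRewriter &rewriter) const override {
        // specializeGenericOp replaces `genericOp` with the named op on success.
        FailureOr<linalg::LinalgOp> namedOp =
            linalg::specializeGenericOp(rewriter, genericOp);
        return failed(namedOp) ? failure() : success();
      }
    };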
