Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);

//===----------------------------------------------------------------------===//
// Convolution matcher utility
//===----------------------------------------------------------------------===//

/// Returns true if `op` structurally matches the named convolution/pooling op
/// `ConvOpTy`. On a successful match, callers that pass non-null `dilations`
/// and `strides` presumably receive the inferred dilation/stride values
/// through those out-parameters (definition not visible here — confirm in the
/// implementation).
template <typename ConvOpTy>
bool isaConvolutionOpOfType(LinalgOp op,
                            SmallVector<int64_t> *dilations = nullptr,
                            SmallVector<int64_t> *strides = nullptr);

//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
Expand Down
144 changes: 144 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

/// Replaces `genericOp` with a named convolution op of type `ConvOpTy`,
/// attaching `dilations` and `strides` attributes when the named op carries
/// them. Returns the newly created named op.
template <typename ConvOpTy>
static FailureOr<LinalgOp>
specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
  SmallVector<Value> inputs = genericOp.getDpsInputs();
  ValueRange outputs = genericOp.getDpsInits();
  // On tensors the named op returns the init values; on memrefs it returns
  // nothing.
  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
                                      ? TypeRange(ValueRange(outputs))
                                      : TypeRange{};
  LinalgOp namedOp;
  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
    // The plain 1D/2D/3D convolutions take no stride/dilation attributes.
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
                                                    inputs, outputs);
  } else {
    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
  }
  return namedOp;
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 2. No matchers are implemented yet.
/// TODO(avarma): Convolution ops with a rank-2 iterator types array will be
/// added here incrementally in follow-up PRs.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 4. Currently only depthwise_conv_1d_nwc_wc is recognized.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  SmallVector<int64_t> dilationVals, strideVals;
  if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
          genericOp, &dilationVals, &strideVals))
    return failure();
  return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
      rewriter, genericOp, dilationVals, strideVals);
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 5. No matchers are implemented yet.
/// TODO(avarma): Convolution ops with a rank-5 iterator types array will be
/// added here incrementally in follow-up PRs.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}

/// Matches `genericOp` against the named op `OpTy` and, on success, replaces
/// it with that named op (forwarding any inferred dilations/strides).
template <typename OpTy>
static FailureOr<LinalgOp> matchAndSpecializeConv(RewriterBase &rewriter,
                                                  GenericOp genericOp) {
  SmallVector<int64_t> dilations, strides;
  if (isaConvolutionOpOfType<OpTy>(genericOp, &dilations, &strides))
    return specializeToConvOp<OpTy>(rewriter, genericOp, dilations, strides);
  return failure();
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 6. Each candidate is tried in turn; the first one that matches wins.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  FailureOr<LinalgOp> maybeOp = failure();
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::DepthwiseConv2DNchwChwOp>(
                    rewriter, genericOp)))
    return maybeOp;
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::PoolingNhwcMaxOp>(
                    rewriter, genericOp)))
    return maybeOp;
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::PoolingNhwcMinOp>(
                    rewriter, genericOp)))
    return maybeOp;
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::PoolingNhwcSumOp>(
                    rewriter, genericOp)))
    return maybeOp;
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::PoolingNhwcMaxUnsignedOp>(
                    rewriter, genericOp)))
    return maybeOp;
  if (succeeded(maybeOp = matchAndSpecializeConv<linalg::PoolingNhwcMinUnsignedOp>(
                    rewriter, genericOp)))
    return maybeOp;
  return failure();
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 7. No matchers are implemented yet.
/// TODO(avarma): Convolution ops with a rank-7 iterator types array will be
/// added here incrementally in follow-up PRs.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 8. No matchers are implemented yet.
/// TODO(avarma): Convolution ops with a rank-8 iterator types array will be
/// added here incrementally in follow-up PRs.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  return failure();
}

/// Specializes convolution-like generic ops whose iterator types array has
/// rank 9. Currently only depthwise_conv_3d_ndhwc_dhwcm is recognized.
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
                                                GenericOp genericOp) {
  SmallVector<int64_t> dilationVals, strideVals;
  if (!isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
          genericOp, &dilationVals, &strideVals))
    return failure();
  return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
      rewriter, genericOp, dilationVals, strideVals);
}

// Converts linalg.generic to a named linalg.*conv/pooling* op where possible.
// Candidates are bucketed by the rank of the iterator types array so only the
// matchers that could possibly apply are tried.
static FailureOr<LinalgOp>
inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
  const unsigned numIterators = genericOp.getIteratorTypesArray().size();
  switch (numIterators) {
  case 2:
    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
  case 4:
    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
  case 5:
    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
  case 6:
    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
  case 7:
    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
  case 8:
    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
  case 9:
    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
  default:
    // No convolution/pooling op uses this iterator rank; nothing to do.
    return failure();
  }
}

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}

// Convolution - e.g. *conv/pooling*
if (isaConvolutionOpInterface(genericOp)) {
return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
}
return failure();
}

Expand Down
Loading