From c6aea9193db7ad415f2fcde00e8bdcc3d98cfea4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Thu, 16 Oct 2025 02:54:14 -0500
Subject: [PATCH 1/7] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops

-- This commit includes the basic infra/utilities to add matchers for
   linalg.*conv*/*pool* ops - such that, given a `linalg.generic` op, it
   identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demonstrate the
   matchers' capability and does so as part of the
   `linalg-specialize-generic-ops` pass.
-- This works towards the aim of [[RFC] Op explosion in
   Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863)
   iteratively for `*conv*/*pooling*` ops.
-- This is part 1 of a series of PRs aimed at adding matchers for convolution
   ops.
-- For further details, refer to
   https://github.com/llvm/llvm-project/pull/163374#pullrequestreview-3341048722

Signed-off-by: Abhishek Varma
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   9 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 144 +++
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 502 ++++++++++++++++++
 .../convolution/roundtrip-convolution.mlir    | 112 ++++
 4 files changed, 767 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant(rewriter, genericOp);
 }
 
+/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
+/// with `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Convolution ops with rank-2 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops with rank-5 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops with rank-7 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops with rank-8 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
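+
+// Example (illustrative sketch, not from the original patch): a generic op
+// obtained from linalg.depthwise_conv_1d_nwc_wc has iterator types
+// [parallel, parallel, parallel, reduction], i.e. rank 4, so only the
+// rank-4 handler above needs to attempt the DepthwiseConv1DNwcWcOp match.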
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of the iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+  }
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+  }
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxSIOp, arith::MaximumFOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinSIOp, arith::MinimumFOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of a dimension expr by a constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///    indexingMaps[0].getResult(iDim) ==
+///        indexingMaps[1].getResult(fDim) * <CST_1> +
+///        indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    } else if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///    indexingMaps[aIndex].getResult(aDim) ==
+///        indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify each map to be of the corresponding
+/// `expectedSize`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector *dilations, + SmallVector *strides, + ArrayRef tempDilations, + ArrayRef tempStrides) { + if (!(dilations && strides)) + return true; + for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { + dilations->push_back(dilation); + strides->push_back(stride); + } + return true; +} + +static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d0, d1, d2, d3, d8, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && 
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = 
yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +template +bool isaConvolutionOpOfType(LinalgOp op, 
SmallVector *dilations, + SmallVector *strides) { + if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); + } else { + return false; + } +} + +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir new file mode 100644 index 0000000000000..5a18ca8519be3 --- /dev/null +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -0,0 +1,112 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. 
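+//
+// Illustrative sketch (editorial addition, not checked by FileCheck): after
+// -linalg-generalize-named-ops, the depthwise_conv_1d_nwc_wc op below becomes
+// a linalg.generic whose indexing maps look roughly like:
+//   affine_map<(d0, d1, d2, d3) -> (d0, d1 * 2 + d3 * 3, d2)>  // input
+//   affine_map<(d0, d1, d2, d3) -> (d3, d2)>                   // filter
+//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>               // output
+// The matcher recovers stride 2 and dilation 3 from the input map, and the
+// specialization pass lifts the generic back to the named op.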
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s + +func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wc +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: 
linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic From cd1b88a9d7febdd1f933ac22254303f74643f1c2 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 17 Oct 2025 02:46:56 -0500 Subject: [PATCH 2/7] Review comment v1.0 --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 143 +++++++++--------- .../convolution/roundtrip-convolution.mlir | 16 +- 2 files changed, 87 insertions(+), 72 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c3c2819652129..4dfec7b361eab 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -418,9 +418,9 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector tempDilations(1, 1); SmallVector tempStrides(1, 1); - // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(N, W, C, w) -> (N, W + w, C)> + // #map1 = affine_map<(N, W, C, w) -> (w, C)> + // #map2 = affine_map<(N, W, C, w) -> (N, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && @@ -449,9 +449,9 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && @@ -483,12 +483,12 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector tempDilations(3, 1); SmallVector tempStrides(3, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d0, d1, d2, d3, d8, d4)> + // #map = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (N, D + d, H + h, W + w, C)> + // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (d, h, w, C, CM)> + // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (N, D, H, W, C, CM)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -526,9 +526,9 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, 
d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -562,9 +562,9 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -598,9 +598,9 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -635,9 +635,9 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -672,9 +672,9 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -689,58 +689,61 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -template -bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - if constexpr (std::is_same_v) { - return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); - } else if 
constexpr (std::is_same_v) { - return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMaxOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMinOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcSumOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); - } else { - return false; - } +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); } -template bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMaxOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMinOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcSumOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); + SmallVector *strides) { + return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); +} Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 5a18ca8519be3..06c9a84049d81 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -99,14 +99,26 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tenso // ----- -func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filter: tensor, %init: tensor) -> tensor { %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins (%input, %filter: tensor, tensor) outs (%init: tensor) -> tensor 
return %0 : tensor } -// CHECK: @pooling_nhwc_min_unsigned +// CHECK: @pooling_nhwc_min_unsigned_integer // CHECK: linalg.pooling_nhwc_min_unsigned // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> // CHECK-NOT: linalg.generic + +func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned_float +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic From 0e9946b47c518867dae394c5221fba7d812c4803 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 22 Oct 2025 03:28:15 -0500 Subject: [PATCH 3/7] Review comment Hanhan v1.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 40 ------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 109 +++++------------- 2 files changed, 32 insertions(+), 117 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 35861002e309e..2bfa21d9062ee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -264,14 +264,6 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, return namedOp; } -/// TODO(avarma): Convolution ops which rank-2 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -283,14 +275,6 @@ inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -/// TODO(avarma): Convolution ops which rank-5 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -322,22 +306,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -/// TODO(avarma): Convolution ops which rank-7 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - -/// TODO(avarma): Convolution ops which rank-8 iteratory types array will be -/// added here incrementally in follow-up PRs. 
-static FailureOr -inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -358,18 +326,10 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { genericOp.getIteratorTypesArray(); unsigned totalIterators = iteratorTypes.size(); switch (totalIterators) { - case 2: - return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); case 4: return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); - case 5: - return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); case 6: return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); - case 7: - return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); - case 8: - return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); case 9: return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 4dfec7b361eab..23c7fb68a5534 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -401,9 +401,10 @@ static bool updateConvDilationsAndStrides(SmallVector *dilations, return true; } -static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -432,9 +433,10 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -466,9 +468,10 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -507,8 +510,10 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -543,8 +548,10 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -579,8 +586,10 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -615,9 +624,10 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, - 
SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -652,9 +662,10 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -689,62 +700,6 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMaxOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMinOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcSumOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); -} - Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From d44cc34ce67daccce72d930f6fea0982ce02a273 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 23 Oct 2025 06:04:50 -0500 Subject: [PATCH 4/7] Review comment Andrszej v2.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 54 ++++--------------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 6 +++ 2 files changed, 17 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 2bfa21d9062ee..ce3df6a485f92 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -264,25 +264,26 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, return namedOp; } +// Converts linalg.generic to named linalg.*conv/pooling* where possible. To +// improve the search speed, the convolution ops have been segregated based on +// the rank of iterator types array. static FailureOr -inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { +inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; + // Depthwise Convolution ops. 
if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - return failure(); -} - -static FailureOr -inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - SmallVector dilations, strides; if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + // Pooling ops. if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, @@ -306,36 +307,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -static FailureOr -inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - SmallVector dilations, strides; - if (isaConvolutionOpOfType( - genericOp, &dilations, &strides)) - return specializeToConvOp( - rewriter, genericOp, dilations, strides); - return failure(); -} - -// Converts linalg.generic to named linalg.*conv/pooling* where possible. To -// improve the search speed, the convolution ops have been segregated based on -// the rank of iterator types array. -static FailureOr -inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { - SmallVector iteratorTypes = - genericOp.getIteratorTypesArray(); - unsigned totalIterators = iteratorTypes.size(); - switch (totalIterators) { - case 4: - return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); - case 6: - return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); - case 9: - return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); - } - return failure(); -} - } // namespace //===----------------------------------------------------------------------===// @@ -417,10 +388,7 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, } // Convolution - e.g. *conv/pooling* - if (isaConvolutionOpInterface(genericOp)) { - return inferAndSpecializeToConvolutionOp(rewriter, genericOp); - } - return failure(); + return inferAndSpecializeToConvolutionOp(rewriter, genericOp); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 23c7fb68a5534..cd518fc38819e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -263,6 +263,9 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { body); } +// max_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. Refer to : +// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -273,6 +276,9 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { body); } +// min_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. 
Refer to : +// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); From 7b47d9e56db22366604e8608d099002cba5e9fd6 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 24 Oct 2025 02:58:28 -0500 Subject: [PATCH 5/7] Review comment Andrszej v3.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 13 +++++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 39 ++++++++++--------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index ce3df6a485f92..c68f7bd88c1ae 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -267,10 +267,12 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, // Converts linalg.generic to named linalg.*conv/pooling* where possible. To // improve the search speed, the convolution ops have been segregated based on // the rank of iterator types array. -static FailureOr -inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; + // ----------------------------- // Depthwise Convolution ops. + //------------------------------ if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( @@ -283,7 +285,9 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); + // ----------------------------- // Pooling ops. + //------------------------------ if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, @@ -388,7 +392,10 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, } // Convolution - e.g. *conv/pooling* - return inferAndSpecializeToConvolutionOp(rewriter, genericOp); + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } + return failure(); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index cd518fc38819e..c5c9e4b2f8387 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -265,7 +265,7 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { // max_unsigned ops should not allow float data type. // TODO: Retire OPDSL logic. Refer to : -// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 +// https://github.com/llvm/llvm-project/issues/164800 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -278,7 +278,7 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { // min_unsigned ops should not allow float data type. // TODO: Retire OPDSL logic. 
Refer to : -// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 +// https://github.com/llvm/llvm-project/issues/164800 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -407,6 +407,9 @@ static bool updateConvDilationsAndStrides(SmallVector *dilations, return true; } +// --------------------------------------------- +// Matchers for specific convolution operation. +//---------------------------------------------- template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -414,8 +417,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) @@ -446,8 +449,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) @@ -481,8 +484,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) @@ -523,8 +526,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -561,8 +564,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -599,8 +602,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -637,8 +640,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -675,8 +678,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) From 47b3e34dc9f6bb882a8d91df0bd09fa2f8c684d3 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 27 Oct 2025 04:02:50 -0500 Subject: [PATCH 6/7] Review comment Andrzej v4.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 6 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp 
From 47b3e34dc9f6bb882a8d91df0bd09fa2f8c684d3 Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Mon, 27 Oct 2025 04:02:50 -0500
Subject: [PATCH 6/7] Review comment Andrzej v4.0

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |   6 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 237 ++++++++++++------
 2 files changed, 158 insertions(+), 85 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c68f7bd88c1ae..0b3662c888010 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,7 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant(rewriter, genericOp);
 }
 
-/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
+/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
 /// with `dilations` and `strides`.
 template <typename ConvOpTy>
 static FailureOr<LinalgOp>
@@ -272,7 +272,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   SmallVector<int64_t> dilations, strides;
   // -----------------------------
   // Depthwise Convolution ops.
-  //------------------------------
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
@@ -287,7 +287,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
         rewriter, genericOp, dilations, strides);
   // -----------------------------
   // Pooling ops.
-  //------------------------------
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
                                                        &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c5c9e4b2f8387..0be2668a9b346 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -244,6 +244,46 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 // Convolution matcher utilities
 //===----------------------------------------------------------------------===//
 
+/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
+/// ext* ops.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+  BlockArgument blockArg;
+  if (!(blockArg = dyn_cast<BlockArgument>(val))) {
+    Operation *defOp = val.getDefiningOp();
+    if (!dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtFOp>(defOp)) {
+      return nullptr;
+    }
+    blockArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  }
+  return blockArg;
+}
+
+/// Utility to match block body for matmul-like ops.
+static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) {
+  Operation *addOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
+    return false;
+
+  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+    return false;
+
+  BlockArgument lhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+  BlockArgument rhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
+      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
+      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
+    return false;
+  return true;
+}
+
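In IR terms, bodyMatcherForMatmulLikeOps accepts a multiply-accumulate region in which the multiplication operands may be routed through the ext* ops traversed above. A schematic body (element types assumed; not an excerpt from the tests):

^bb0(%in: i8, %flt: i8, %acc: i32):
  %0 = arith.extsi %in : i8 to i32    // optional ext over block argument #0
  %1 = arith.extsi %flt : i8 to i32   // optional ext over block argument #1
  %2 = arith.muli %0, %1 : i32        // must feed operand #1 of the add
  %3 = arith.addi %acc, %2 : i32      // operand #0 must lead to argument #2
  linalg.yield %3 : i32

The argument-number checks are what distinguish a genuine input * filter + acc update from an arbitrary mul/add chain in the body.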
 /// Utility to match block body for linalg.pool* ops.
 template <typename... OpTypes>
 static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
@@ -253,7 +293,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
 
   BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
   BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
-  if (!lhsArg || !rhsArg)
+  if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
+      rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
+      rhsArg.getArgNumber() != 0)
     return false;
   return true;
 }
@@ -339,8 +381,9 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
 static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                        unsigned fDim, unsigned oDim,
                                        int64_t &dilation, int64_t &stride) {
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
-  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  unsigned inputMapIdx = 0, filterMapIdx = 1,
+           outputMapIdx = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
   auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
   if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
     return false;
@@ -351,8 +394,8 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
   if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
       isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
     // Pattern matched with dims and constants extracted.
-    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
-    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
     if (dim0 == fExpr && dim1 == oExpr) {
       dilation = c0;
       stride = c1;
@@ -394,22 +437,26 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
 
 /// Utility to update `dilations` and `strides` by copying the corresponding
 /// data from `tempDilations` and `tempStrides`.
-static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+static void updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
                                           SmallVector<int64_t> *strides,
                                           ArrayRef<int64_t> tempDilations,
                                           ArrayRef<int64_t> tempStrides) {
   if (!(dilations && strides))
-    return true;
+    return;
   for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
     dilations->push_back(dilation);
     strides->push_back(stride);
   }
-  return true;
+  return;
 }
 
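The new getOwner/getArgNumber checks pin the pooling reduction to the right operands of the combinator: the yielded value must combine block argument 2 (the init) with block argument 0 (the input), in that operand order, and both must belong to the generic's own body. Schematically (f32 assumed), only the first of these two regions still matches:

^bb0(%in: f32, %win: f32, %out: f32):
  %0 = arith.maximumf %out, %in : f32   // block args (2, 0): matches
  linalg.yield %0 : f32

^bb0(%in: f32, %win: f32, %out: f32):
  %0 = arith.maximumf %in, %win : f32   // block args (0, 1): rejected
  linalg.yield %0 : f32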
 // ---------------------------------------------
 // Matchers for specific convolution operations.
-//----------------------------------------------
+// ---------------------------------------------
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -424,24 +471,30 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
     return false;
 
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(1, 1);
   SmallVector<int64_t> tempStrides(1, 1);
-  // #map = affine_map<(N, W, C, w) -> (N, W + w, C)>
-  // #map1 = affine_map<(N, W, C, w) -> (w, C)>
-  // #map2 = affine_map<(N, W, C, w) -> (N, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, filterMapIdx, 1) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, outputMapIdx, 2) &&
        matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
-                                  tempStrides[0]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+                                  tempStrides[0]) &&
+       bodyMatcherForMatmulLikeOps(yieldVal, body));
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
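A sketch of a generic this DepthwiseConv1DNwcWcOp matcher should recognize with non-unit stride and dilation (shapes assumed): the W * 2 + w * 3 expression makes matchConvDimAddExprPattern record strides = [2] and dilations = [3].

#in = affine_map<(N, W, C, w) -> (N, W * 2 + w * 3, C)>
#fil = affine_map<(N, W, C, w) -> (w, C)>
#out = affine_map<(N, W, C, w) -> (N, W, C)>
func.func @depthwise_conv_1d(%i: tensor<1x15x4xf32>, %f: tensor<3x4xf32>,
                             %o: tensor<1x5x4xf32>) -> tensor<1x5x4xf32> {
  %0 = linalg.generic
      {indexing_maps = [#in, #fil, #out],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%i, %f : tensor<1x15x4xf32>, tensor<3x4xf32>)
      outs(%o : tensor<1x5x4xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %1 = arith.mulf %a, %b : f32
    %2 = arith.addf %c, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<1x5x4xf32>
  return %0 : tensor<1x5x4xf32>
}

Since the constants attach to the output and filter dims commutatively, the doc comment added in the last patch of this series spells out which multiplier becomes the stride and which the dilation.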
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
 template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -456,27 +509,36 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
     return false;
 
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, filterMapIdx, 0) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, outputMapIdx, 1) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
                                   /*oDim=*/3, tempDilations[1],
-                                  tempStrides[1]));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+                                  tempStrides[1]) &&
+       bodyMatcherForMatmulLikeOps(yieldVal, body));
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
+// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (N, D + d, H + h, W + w, C)>
+// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (d, h, w, C, CM)>
+// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (N, D, H, W, C, CM)>
 template <>
 bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -491,18 +553,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
   if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
     return false;
 
-  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(3, 1);
   SmallVector<int64_t> tempStrides(3, 1);
-  // #map = affine_map<(N, D, H, W, CM, d, h, w, C)
-  //         -> (N, D + d, H + h, W + w, C)>
-  // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C)
-  //         -> (d, h, w, C, CM)>
-  // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C)
-  //         -> (N, D, H, W, C, CM)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
@@ -512,13 +571,20 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
                                   /*oDim=*/3, tempDilations[2],
                                   tempStrides[2]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
-       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, filterMapIdx, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, outputMapIdx, 4) &&
+       matchConvDimExprPattern(indexingMaps, filterMapIdx, 4, outputMapIdx,
+                               5) &&
+       bodyMatcherForMatmulLikeOps(yieldVal, body));
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -536,27 +602,29 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  unsigned iIndex = 0, oIndex = 2;
+  unsigned inputMapIdx = 0, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[1],
                                   tempStrides[1]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
        bodyMatcherForMaxSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -574,27 +642,29 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  unsigned iIndex = 0, oIndex = 2;
+  unsigned inputMapIdx = 0, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[1],
                                   tempStrides[1]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
        bodyMatcherForMinSignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -612,27 +682,29 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  unsigned iIndex = 0, oIndex = 2;
+  unsigned inputMapIdx = 0, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[1],
                                   tempStrides[1]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
        bodyMatcherForSumPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
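The two unsigned pooling matchers below follow the same skeleton; the only difference is the combinator their body matchers expect. Assuming the elided bodyMatcherForPoolOps template arguments map unsigned max/min to arith.maxui/arith.minui, a matching body looks like (element type assumed):

^bb0(%in: i32, %win: i32, %out: i32):
  %0 = arith.maxui %out, %in : i32   // matched by bodyMatcherForMaxUnsignedPoolOps
  linalg.yield %0 : i32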
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -650,27 +722,29 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  unsigned iIndex = 0, oIndex = 2;
+  unsigned inputMapIdx = 0, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[1],
                                   tempStrides[1]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
        bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
 template <>
 bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
@@ -688,25 +762,24 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
   Value yieldVal = yieldOp.getOperand(0);
-  unsigned iIndex = 0, oIndex = 2;
+  unsigned inputMapIdx = 0, outputMapIdx = 2;
   SmallVector<int64_t> tempDilations(2, 1);
   SmallVector<int64_t> tempStrides(2, 1);
-  // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
-  // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
-  // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
   bool returnVal =
-      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+      (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, tempDilations[0],
                                   tempStrides[0]) &&
       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, tempDilations[1],
                                   tempStrides[1]) &&
-       matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
        bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
-  return returnVal && updateConvDilationsAndStrides(dilations, strides,
-                                                    tempDilations, tempStrides);
+  if (returnVal)
+    updateConvDilationsAndStrides(dilations, strides, tempDilations,
+                                  tempStrides);
+  return returnVal;
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,

From ab8eb8f5354aa0d3436f47cabfacd228c5cc5ea4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Tue, 4 Nov 2025 04:09:18 -0600
Subject: [PATCH 7/7] Doc comment + function signature change

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 73 ++++++++++++++-----------
 1 file changed, 40 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0be2668a9b346..53669542cdb91 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -338,46 +338,53 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
   return nullptr;
 }
 
-// Check if `expr` is either:
-// - a dimension expr alone (implying *1), or
-// - a multiplication of dimension expr by constant.
-static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
-                                        int64_t &constantValue) {
-  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
-    dim = dExpr;
-    constantValue = 1;
-    return true;
-  }
+/// Check if `expr` is either:
+/// - a dimension expr alone (implying multiplication by 1), or
+/// - a multiplication of a dimension expr by any positive constant != 1.
+/// In both cases we will capture the dimension expression into `dim` and
+/// return the constant multiplier. Returns -1 in case of a match failure.
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+  if ((dim = dyn_cast<AffineDimExpr>(expr)))
+    return 1;
   auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
   if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
-    return false;
+    return -1;
 
   AffineExpr lhs = mulExpr.getLHS();
   AffineExpr rhs = mulExpr.getRHS();
-  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
-    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
-    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  return false;
+  AffineConstantExpr cst = nullptr;
+  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(lhs))))
+    return cst.getValue();
+  return -1;
 }
 
-/// Given an array of AffineMaps `indexingMaps` verify the following :-
+/// Given an array of AffineMaps `indexingMaps` verify the following
+/// commutatively:-
 ///    indexingMaps[0].getResult(iDim) ==
-///        indexingMaps[1].getResult(fDim) * <CST_1> +
-///        indexingMaps[n-1].getResult(oDim) * <CST_2>
-/// where, CST_1 and CST_2 can be any constant.
+///        indexingMaps[1].getResult(fDim) * <c0> +
+///        indexingMaps[n-1].getResult(oDim) * <c1>
+/// where,
+///   - c0 and c1 can be any constant,
+///   - n is the size of the `indexingMaps` array,
+///   - 0, 1 and n-1 are the input, filter and output map indices
+///     respectively,
+///   - iDim, fDim and oDim are the input, filter and output dimension
+///     indices in their respective indexing maps.
+/// Example:
+///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+///                 -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+/// Here,
+///   #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+/// Therefore,
+///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+/// would return true and update dilation = 3 and stride = 2.
 static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                        unsigned fDim, unsigned oDim,
                                        int64_t &dilation, int64_t &stride) {
@@ -389,10 +396,10 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
     return false;
 
   AffineExpr dim0, dim1;
-  int64_t c0, c1;
+  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
 
-  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
-      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+  if (c0 != -1 && c1 != -1) {
     // Pattern matched with dims and constants extracted.
     AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
     AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);