diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 48978eb7663d5..771d753a8bddb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to); std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); +//===----------------------------------------------------------------------===// +// Convolution matcher utility +//===----------------------------------------------------------------------===// + +template +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d68e358f..929904fa2c510 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,286 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. +template +static FailureOr +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef dilations, ArrayRef strides) { + SmallVector inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector resultTypes = genericOp.hasPureTensorSemantics() + ? 
TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; +} + +static FailureOr +inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, + strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, + strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return 
specializeToConvOp(rewriter, genericOp, + dilations, strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, + strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return 
specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + 
return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + return failure(); +} + +static FailureOr +inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + return failure(); +} + +// Converts linalg.generic to named linalg.*conv/pooling* where possible. 
To +// improve the search speed, the convolution ops have been segregated based on +// the rank of iterator types array. +static FailureOr +inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); + unsigned totalIterators = iteratorTypes.size(); + switch (totalIterators) { + case 2: + return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); + case 4: + return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); + case 5: + return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); + case 6: + return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); + case 7: + return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); + case 8: + return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); + case 9: + return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); + } + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +596,11 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. 
*conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return inferAndSpecializeToConvolutionOp(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 3593b5348d268..13235d99887a7 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,6 +240,1973 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Utility to match block body for linalg.pool* ops. +template +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present(defOp) || ...)) + return false; + + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, + uint32_t mapIndex, uint32_t dimIndex) { + auto affineMap = cast(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return 
affineMap.getResult(dimIndex); + return nullptr; +} + +// Check if `expr` is either: +// - a dimension expr alone (implying *1), or +// - a multiplication of dimension expr by constant. +static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, + int64_t &constantValue) { + if (auto dExpr = dyn_cast(expr)) { + dim = dExpr; + constantValue = 1; + return true; + } + + auto mulExpr = dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + if (auto dExpr = dyn_cast(lhs)) { + if (auto cst = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + if (auto cst = dyn_cast(lhs)) { + if (auto dExpr = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + return false; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * + +/// indexingMaps[n-1].getResult(oDim) * +/// where, CST_1 and CST_2 can be any constant. +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); + auto addExpr = dyn_cast(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + int64_t c0, c1; + + if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && + isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { + // Pattern matched with dims and constants extracted. 
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } else if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } + } + return false; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[aIndex].getResult(aDim) == +/// indexingMaps[bIndex].getResult(bDim) +static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, + unsigned aDim, unsigned bIndex, + unsigned bDim) { + return getAffineMapDim(indexingMaps, aIndex, aDim) == + getAffineMapDim(indexingMaps, bIndex, bDim); +} + +/// Give an array of AffineMaps, verify each map to be of the corresponding +/// `expectedSize`. +static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, + ArrayRef expectedSizes) { + if (indexingMaps.size() != expectedSizes.size()) + return false; + + for (auto [indexingMap, expectedSize] : + llvm::zip_equal(indexingMaps, expectedSizes)) { + auto affineMap = cast(indexingMap).getValue(); + if (affineMap.getNumResults() != expectedSize) + return false; + } + return true; +} + +/// Utility to update `dilations` and `strides` by copy the corresponding data +/// from `tempDilations` and `tempStrides`. 
+static bool updateConvDilationsAndStrides(SmallVector *dilations, + SmallVector *strides, + ArrayRef tempDilations, + ArrayRef tempStrides) { + if (!(dilations && strides)) + return true; + for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { + dilations->push_back(dilation); + strides->push_back(stride); + } + return true; +} + +static bool isaConv1DOp(LinalgOp op) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {1, 1, 1})) + return false; + + // #map = affine_map<(d0, d1) -> (d0 + d1)> + // #map1 = affine_map<(d0, d1) -> (d1)> + // #map2 = affine_map<(d0, d1) -> (d0)> + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]); +} + +static bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static 
bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0])); + 
return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +// ------------------- +static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + 
tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DOp(LinalgOp op) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {2, 2, 2})) + return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[1], + tempStrides[1])); +} + +static bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + 
/*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, 
d4, d2 + d5, d3 + + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + 
tempDilations, tempStrides); +} + +static bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, 
d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal 
&& updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, + // d4, d5, d6, d7)> #map2 = 
affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4) + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, + // d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + 
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, + // d4, d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + + 
unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); 
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, + // d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, + // d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, 
d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 0, 0, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); + return returnVal && 
updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv3DOp(LinalgOp op) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[2], + tempStrides[2])); +} + +static bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, + // d3 + d7, d4 + d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d1, d5, d6, d7, d8)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + 
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr 
indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> ()> #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> + // (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d8, d4)> + bool 
returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + // + d5, d3 + d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, + // d4, d5, d6)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d7, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, + /*oDim=*/4, tempDilations[2], + 
tempStrides[2])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + + // d5, d3 + d6, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d4, d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d0, d1, d2, d3, d7)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector 
tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + bodyMatcherForSumPoolOps(yieldVal, body)); + 
return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + 
d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static 
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = 
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if 
(!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + 
tempDilations, tempStrides); +} + +static bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, 
/*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if 
(!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // 
d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +template +bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if constexpr (std::is_same_v) { + return isaConv1DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv1DNwcWcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv1DNcwFcwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNcwCwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNwcWcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcFhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcHwcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNchwFchwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcFhwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNchwFchwQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNgchwFgchwOp(op, dilations, strides); + } else if constexpr 
(std::is_same_v) { + return isaConv2DNgchwGfchwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcHwcfQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwgcGfhwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNgchwGfchwQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwgcGfhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcmQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv3DNcdhwFcdhwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DNdhwcDhwcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DNdhwcDhwcfQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNcdhwCdhwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNdhwcDhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNchwMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNchwSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return 
isaPoolingNhwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNcwMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNcwSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcSumOp(op, dilations, strides); + } else { + return false; + } +} + +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool 
isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, 
SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git 
a/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir new file mode 100644 index 0000000000000..8cd57044e613f --- /dev/null +++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir @@ -0,0 +1,615 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. +// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s + +func.func @conv_1d_nwc_wcf(%input: memref, %filter: memref, %output: memref) { + linalg.conv_1d_nwc_wcf {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_1d_nwc_wcf +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_1d_ncw_fcw(%input: memref, %filter: memref, %output: memref) { + linalg.conv_1d_ncw_fcw {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_1d_ncw_fcw +// CHECK: linalg.conv_1d_ncw_fcw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_1d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_1d ins(%in, %filter : memref, memref) + outs(%out : memref) + return +} +// CHECK: @conv_1d +// CHECK: linalg.conv_1d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_ncw_cw(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: 
@depthwise_conv_1d_ncw_cw +// CHECK: linalg.depthwise_conv_1d_ncw_cw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wc +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_nwc_wcm(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wcm +// CHECK: linalg.depthwise_conv_1d_nwc_wcm +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_2d ins(%in, %filter : memref, memref) + outs(%out: memref) + return +} +// CHECK: @conv_2d +// CHECK: linalg.conv_2d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nchw_fchw(%arg0 : tensor, + %arg1 : tensor, %arg2 : tensor) -> + (tensor<4x8x12x16xf32>, tensor) { + %0 = linalg.conv_2d_nchw_fchw {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %1 = tensor.cast %0 : tensor to tensor<4x8x12x16xf32> + return %1, %0 : tensor<4x8x12x16xf32>, tensor +} +// CHECK: @conv_2d_nchw_fchw +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : 
tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nchw_fchw_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw_q +// CHECK: linalg.conv_2d_nchw_fchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_fgchw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_fgchw +// CHECK: linalg.conv_2d_ngchw_fgchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_gfchw(%input: memref, %filter: memref, %output: memref) { + linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_2d_ngchw_gfchw +// CHECK: linalg.conv_2d_ngchw_gfchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_gfchw_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: memref, memref, i32, i32) + outs (%output: memref) + return +} +// CHECK: @conv_2d_ngchw_gfchw_q +// CHECK: 
linalg.conv_2d_ngchw_gfchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_hwcf_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_nhwc_hwcf_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : memref, memref, i32, i32) outs(%output : memref) + return +} +// CHECK: @conv_2d_nhwc_hwcf_q +// CHECK: linalg.conv_2d_nhwc_hwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwgc_gfhwc_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_nhwgc_gfhwc_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : memref, memref, i32, i32) outs(%output : memref) + return +} +// CHECK: @conv_2d_nhwgc_gfhwc_q +// CHECK: linalg.conv_2d_nhwgc_gfhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %output: tensor) -> tensor{ + %res = linalg.depthwise_conv_2d_nhwc_hwc_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : tensor, tensor, i32, i32) outs(%output : tensor) -> tensor + return %res : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_fhwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : 
tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc +// CHECK: linalg.conv_2d_nhwc_fhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_fhwc_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc_q +// CHECK: linalg.conv_2d_nhwc_fhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_hwcf(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwgc_gfhwc(%input: memref, %filter: memref, %output: memref) { + linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_2d_nhwgc_gfhwc +// CHECK: linalg.conv_2d_nhwgc_gfhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = 
linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : i32, %arg4 : i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q +// CHECK-SAME: dilations = dense<1> : 
tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_3d ins(%in, %filter : memref, memref) + outs(%out : memref) + return +} +// CHECK: @conv_3d +// CHECK: linalg.conv_3d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ncdhw_fcdhw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ncdhw_fcdhw +// CHECK: linalg.conv_3d_ncdhw_fcdhw +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ndhwc_dhwcf(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf +// CHECK: linalg.conv_3d_ndhwc_dhwcf +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf_q {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins(%input, %filter, %inputzp, %filterzp : tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf_q +// CHECK: linalg.conv_3d_ndhwc_dhwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor, %filter: tensor, 
%init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ncdhw_cdhw +// CHECK: linalg.depthwise_conv_3d_ncdhw_cdhw +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwc +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwc +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nchw_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nchw_max +// CHECK: linalg.pooling_nchw_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// 
CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nchw_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nchw_sum +// CHECK: linalg.pooling_nchw_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ncw_max(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_ncw_max {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ncw_max +// CHECK: linalg.pooling_ncw_max +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ncw_sum(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_ncw_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ncw_sum +// CHECK: linalg.pooling_ncw_sum +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ndhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_max +// CHECK: linalg.pooling_ndhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + 
+func.func @pooling_ndhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_min +// CHECK: linalg.pooling_ndhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ndhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_sum +// CHECK: linalg.pooling_ndhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func 
@pooling_nhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nwc_max(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> +// CHECK-NOT: linalg.generic + 
+// ----- + +func.func @pooling_nwc_min(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_min {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_min +// CHECK: linalg.pooling_nwc_min +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nwc_sum(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_sum +// CHECK: linalg.pooling_nwc_sum +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic