From a964bb437e612c57ce741157977502ed1f7815fb Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 22 Sep 2025 12:16:48 +0000 Subject: [PATCH 01/18] [WIP] Generic to named Conv op support Signed-off-by: Abhishek Varma --- .../Dialect/Linalg/Transforms/Specialize.cpp | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d68e358f..4e9572ee7cb04 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,159 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +static bool matchingIteratorTypes(ArrayRef iteratorTypes, +ArrayRef expectedIteratorTypes) { + if (iteratorTypes.size() != expectedIteratorTypes.size()) return false; + for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) { + if (orig != expected) return false; + } + return true; +} + +static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, + uint32_t mapIndex, uint32_t dimIndex) { + auto affineMap = cast(indexingMaps[mapIndex]).getValue(); + // uint32_t nResults = affineMap.getNumResults(); + // llvm::outs()< iteratorTypes = genericOp.getIteratorTypesArray(); + SmallVector expectedIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::reduction + }; + + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) + return "linalg.conv_1d"; + return ""; +} + +static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() != 3) return ""; + SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); + // Conv 1D + // depthwise_conv_1d_ncw_cw + // depthwise_conv_1d_nwc_wc + // ["parallel", "parallel", "parallel", "reduction"] + SmallVector expectedIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction + }; + // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { + if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)) + return "linalg.depthwise_conv_1d_ncw_cw"; + else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)) + return "linalg.depthwise_conv_1d_nwc_wc"; + } + + // + expectedIteratorTypes[2] = utils::IteratorType::reduction; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { + return "linalg.conv_2d"; + } + return ""; +} + +static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() != 3) return ""; + SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); + // "parallel", "parallel", "parallel", "reduction", "reduction"] + SmallVector expectedIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::reduction + }; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) + return "linalg.depthwise_conv_1d_nwc_wcm"; + + expectedIteratorTypes[3] = utils::IteratorType::reduction; + // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { + if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) + return "linalg.conv_1d_nwc_wcf"; + else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)) + return "linalg.conv_1d_ncw_fcw"; + } + return ""; +} + +static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { + SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); + SmallVector expectedIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::reduction + }; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) + return "linalg.conv_1d"; + return ""; +} + +static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { + SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); + SmallVector expectedIteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::reduction + }; + if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) + return "linalg.conv_1d"; + return ""; +} + +static std::string inferConvolutionKind(GenericOp genericOp) { + SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); + unsigned totalIterators = iteratorTypes.size(); + switch(totalIterators) { + case 2: + return inferBasedOnRank2ConvIteratorTypes(genericOp); + case 4: + return inferBasedOnRank4ConvIteratorTypes(genericOp); + case 5: + return inferBasedOnRank5ConvIteratorTypes(genericOp); + case 7: + return inferBasedOnRank7ConvIteratorTypes(genericOp); + case 8: + return inferBasedOnRank8ConvIteratorTypes(genericOp); + } + return ""; +} + +// Converts linalg.generic to named linalg.*conv* where possible. +static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { + std::string convKind = inferConvolutionKind(genericOp); + if (convKind == "") return failure(); + SmallVector inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + if (convKind == "linalg.conv_1d") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_1d_nwc_wcf") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_1d_ncw_fcw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } + return namedOp; + + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +469,11 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. *conv* + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } return failure(); } From 89b7190e79b16954b254acca9b5f4d8f6f7c9eb6 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 29 Sep 2025 17:27:30 +0000 Subject: [PATCH 02/18] Matching indexing maps --- .../Dialect/Linalg/Transforms/Specialize.cpp | 257 +++++++++++++----- 1 file changed, 187 insertions(+), 70 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 4e9572ee7cb04..84b080fd53535 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,33 +237,17 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } -static bool matchingIteratorTypes(ArrayRef iteratorTypes, -ArrayRef expectedIteratorTypes) { - if (iteratorTypes.size() != expectedIteratorTypes.size()) return false; - for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) { - if (orig != expected) return false; - } - return true; -} - static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex) { auto affineMap = cast(indexingMaps[mapIndex]).getValue(); - // uint32_t nResults = affineMap.getNumResults(); - // llvm::outs()< iteratorTypes = genericOp.getIteratorTypesArray(); - SmallVector expectedIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::reduction - }; - - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() != 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) return "linalg.conv_1d"; return ""; } @@ -271,74 +255,187 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() != 3) return ""; - SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); - // Conv 1D + unsigned iIndex = 0, fIndex = 1, oIndex = 2; // depthwise_conv_1d_ncw_cw + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2)))) + return "linalg.depthwise_conv_1d_ncw_cw"; // depthwise_conv_1d_nwc_wc - // ["parallel", "parallel", "parallel", "reduction"] - SmallVector expectedIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction - }; - // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { - if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)) - return "linalg.depthwise_conv_1d_ncw_cw"; - else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)) - return "linalg.depthwise_conv_1d_nwc_wc"; - } - - // - expectedIteratorTypes[2] = utils::IteratorType::reduction; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)))) + return "linalg.depthwise_conv_1d_nwc_wc"; + // conv_2d + // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1)))) return "linalg.conv_2d"; - } return ""; } static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() != 3) return ""; - SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); - // "parallel", "parallel", "parallel", "reduction", "reduction"] - SmallVector expectedIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction - }; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) - return "linalg.depthwise_conv_1d_nwc_wcm"; - - expectedIteratorTypes[3] = utils::IteratorType::reduction; - // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2; unsigned iIndex = 0, fIndex = 1, oIndex = 2; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) { - if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) - return "linalg.conv_1d_nwc_wcf"; - else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)) - return "linalg.conv_1d_ncw_fcw"; - } + // depthwise_conv_1d_nwc_wcm + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && + (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3))) + return "linalg.depthwise_conv_1d_nwc_wcm"; + // conv_1d_nwc_wcf + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) && + (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) + return "linalg.conv_1d_nwc_wcf"; + // conv_1d_ncw_fcw + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + return "linalg.conv_1d_ncw_fcw"; return ""; } static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { - SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); - SmallVector expectedIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::reduction - }; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) - return "linalg.conv_1d"; + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + // conv_2d_nhwc_fhwc + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + if (indexingMaps.size() == 3 && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) + return "linalg.conv_2d_nhwc_fhwc"; + // conv_2d_nhwc_hwcf + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) && + (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) + return "linalg.conv_2d_nhwc_hwcf"; + // conv_2d_nchw_fchw + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + if (indexingMaps.size() == 3 && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + return "linalg.conv_2d_nchw_fchw"; + // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps) + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + if (indexingMaps.size() == 5 && + (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) + return "linalg.conv_2d_nhwc_fhwc_q"; + // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps) + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + llvm::outs()<<"Indexing map size = "<(indexingMaps[2]).getValue().getNumResults() = "<(indexingMaps[2]).getValue().getNumResults()<<"\n"; + if (indexingMaps.size() == 5 && + (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + return "linalg.conv_2d_nchw_fchw_q"; return ""; } static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { - SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); - SmallVector expectedIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::reduction - }; - if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) - return "linalg.conv_1d"; + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + // conv_2d_ngchw_fgchw + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + return "linalg.conv_2d_ngchw_fgchw"; + // conv_2d_ngchw_gfchw + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + if (indexingMaps.size() == 3 && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) + return "linalg.conv_2d_ngchw_gfchw"; + // conv_2d_ngchw_gfchw_q + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + if (indexingMaps.size() == 5 && + (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) + return "linalg.conv_2d_ngchw_gfchw_q"; + // conv_2d_nhwgc_gfhwc + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && + (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4))) + return "linalg.conv_2d_nhwgc_gfhwc"; return ""; } @@ -382,8 +479,28 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_2d") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwc_fhwc") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwc_hwcf") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nchw_fchw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nchw_fchw_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_ngchw_fgchw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_ngchw_gfchw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } return namedOp; From dac92f142afdfd6a7a616574e1b0983513a37ba1 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 3 Oct 2025 06:56:34 -0500 Subject: [PATCH 03/18] Conv complete -> start Pool op now --- .../Dialect/Linalg/Transforms/Specialize.cpp | 146 +++++++++++++++++- 1 file changed, 142 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 84b080fd53535..6603967b991ab 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -316,6 +316,39 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { return ""; } +static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + // depthwise_conv_2d_nchw_chw + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3)))) + return "linalg.depthwise_conv_2d_nchw_chw"; + // depthwise_conv_2d_nhwc_hwc + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) + return "linalg.depthwise_conv_2d_nhwc_hwc"; + // conv_3d + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2)))) + return "linalg.conv_3d"; + return ""; +} + static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() < 3) return ""; @@ -370,9 +403,6 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - llvm::outs()<<"Indexing map size = "<(indexingMaps[2]).getValue().getNumResults() = "<(indexingMaps[2]).getValue().getNumResults()<<"\n"; if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && @@ -381,6 +411,30 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) return "linalg.conv_2d_nchw_fchw_q"; + // depthwise_conv_2d_nhwc_hwcm + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + if (indexingMaps.size() == 3 && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) + return "linalg.depthwise_conv_2d_nhwc_hwcm"; + // depthwise_conv_2d_nhwc_hwcm_q + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + if (indexingMaps.size() == 5 && + (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && + (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) + return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; return ""; } @@ -397,7 +451,7 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2))) return "linalg.conv_2d_ngchw_fgchw"; // conv_2d_ngchw_gfchw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> @@ -436,6 +490,66 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.conv_2d_nhwgc_gfhwc"; + // depthwise_conv_3d_ncdhw_cdhw + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4)))) + return "linalg.depthwise_conv_3d_ncdhw_cdhw"; + // depthwise_conv_3d_ndhwc_dhwc + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) + return "linalg.depthwise_conv_3d_ndhwc_dhwc"; + return ""; +} + +static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + // conv_3d_ncdhw_fcdhw + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + return "linalg.conv_3d_ncdhw_fcdhw"; + // conv_3d_ndhwc_dhwcf + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) + return "linalg.conv_3d_ndhwc_dhwcf"; + // depthwise_conv_3d_ndhwc_dhwcm + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) && + (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5))) + return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; return ""; } @@ -449,10 +563,14 @@ static std::string inferConvolutionKind(GenericOp genericOp) { return inferBasedOnRank4ConvIteratorTypes(genericOp); case 5: return inferBasedOnRank5ConvIteratorTypes(genericOp); + case 6: + return inferBasedOnRank6ConvIteratorTypes(genericOp); case 7: return inferBasedOnRank7ConvIteratorTypes(genericOp); case 8: return inferBasedOnRank8ConvIteratorTypes(genericOp); + case 9: + return inferBasedOnRank9ConvIteratorTypes(genericOp); } return ""; } @@ -501,6 +619,26 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_3d") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } return namedOp; From 789fb856517377966b58ba9b1c4fdea1d1b12324 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 7 Oct 2025 09:04:37 -0500 Subject: [PATCH 04/18] Add pooling ops to the mix - has few issues but we can shift to considering dilations/strides now --- .../Dialect/Linalg/Transforms/Specialize.cpp | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 6603967b991ab..2efa410e4b855 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,39 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +/// Utility to match block body for linalg.pool* ops. +template +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + // if (!defOp) return false; + if (!(isa_and_present(defOp) || ...)) return false; + + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg) return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex) { auto affineMap = cast(indexingMaps[mapIndex]).getValue(); @@ -279,6 +312,39 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) && (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1)))) return "linalg.conv_2d"; + + Block *body = genericOp.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + // pooling_ncw_max + // pooling_ncw_sum + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) { + if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return "linalg.pooling_ncw_max"; + if (bodyMatcherForSumPoolOps(yieldVal, body)) + return "linalg.pooling_ncw_sum"; + } + // pooling_nwc_max + // pooling_nwc_min + // pooling_nwc_sum + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) { + if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return "linalg.pooling_nwc_max"; + if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) + return "linalg.pooling_nwc_min"; + if (bodyMatcherForSumPoolOps(yieldVal, body)) + return "linalg.pooling_nwc_sum"; + } return ""; } @@ -346,6 +412,55 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2)))) return "linalg.conv_3d"; + + Block *body = genericOp.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + // pooling_nchw_max + // pooling_nchw_sum + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) { + if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return "linalg.pooling_nchw_max"; + if (bodyMatcherForSumPoolOps(yieldVal, body)) + return "linalg.pooling_nchw_sum"; + } + // pooling_nhwc_max + // pooling_nhwc_min + // pooling_nhwc_sum + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { + if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return "linalg.pooling_nhwc_max"; + if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) + return "linalg.pooling_nhwc_min"; + if (bodyMatcherForSumPoolOps(yieldVal, body)) + return "linalg.pooling_nhwc_sum"; + } + // pooling_nhwc_max_unsigned + // pooling_nhwc_min_unsigned + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { + if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) + return "linalg.pooling_nhwc_max_unsigned"; + if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) + return "linalg.pooling_nhwc_max_unsigned"; + } return ""; } @@ -510,6 +625,28 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.depthwise_conv_3d_ndhwc_dhwc"; + + Block *body = genericOp.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + // pooling_ndhwc_max + // pooling_ndhwc_min + // pooling_ndhwc_sum + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) { + if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return "linalg.pooling_ndhwc_max"; + if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) + return "linalg.pooling_ndhwc_min"; + if (bodyMatcherForSumPoolOps(yieldVal, body)) + return "linalg.pooling_ndhwc_sum"; + } return ""; } @@ -639,6 +776,36 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nchw_max") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nchw_sum") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nhwc_max") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nhwc_min") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nhwc_sum") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nhwc_max_unsigned") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nhwc_min_unsigned") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_ncw_max") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_ncw_sum") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nwc_max") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nwc_min") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_nwc_sum") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_ndhwc_max") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_ndhwc_min") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.pooling_ndhwc_sum") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } return namedOp; From 8e65bc6fc1a75de626a62ad72784087b57f2cd65 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 8 Oct 2025 03:42:40 -0500 Subject: [PATCH 05/18] Concisely v1.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 199 ++++++++++++------ 1 file changed, 130 insertions(+), 69 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 2efa410e4b855..01a5c3bebd146 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -273,14 +273,75 @@ static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex) { auto affineMap = cast(indexingMaps[mapIndex]).getValue(); - return affineMap.getResult(dimIndex); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +// Check if `expr` is either: +// - a dimension expr alone (implying *1), or +// - a multiplication of dimension expr by constant. +bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) { + if (auto dExpr = dyn_cast(expr)) { + dim = dExpr; + constantValue = 1; + return true; + } + + auto mulExpr = dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + if (auto dExpr = dyn_cast(lhs)) { + if (auto cst = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + if (auto cst = dyn_cast(lhs)) { + if (auto dExpr = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + return false; +} + +bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) { + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); + auto addExpr = dyn_cast(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + // TODO(Abhishek-Varma): Use this information in specialize.cpp. + int64_t c0, c1; + + if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && + isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { + // Pattern matched with dims and constants extracted. + AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); + return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr)); + } + return false; +} + +bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) { + return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim); } static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() != 3) return ""; unsigned iIndex = 0, fIndex = 1, oIndex = 2; - if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) + if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0)) return "linalg.conv_1d"; return ""; } @@ -295,7 +356,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2)))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2)) return "linalg.depthwise_conv_1d_ncw_cw"; // depthwise_conv_1d_nwc_wc // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> @@ -303,14 +364,14 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1)) return "linalg.depthwise_conv_1d_nwc_wc"; // conv_2d // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1)))) + if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1)) return "linalg.conv_2d"; Block *body = genericOp.getBlock(); @@ -323,7 +384,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) { + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_ncw_max"; if (bodyMatcherForSumPoolOps(yieldVal, body)) @@ -336,7 +397,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_nwc_max"; @@ -357,7 +418,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3))) return "linalg.depthwise_conv_1d_nwc_wcm"; @@ -366,7 +427,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) && (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) return "linalg.conv_1d_nwc_wcf"; @@ -376,7 +437,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) return "linalg.conv_1d_ncw_fcw"; return ""; @@ -392,25 +453,25 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3)))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3)) return "linalg.depthwise_conv_2d_nchw_chw"; // depthwise_conv_2d_nhwc_hwc // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) return "linalg.depthwise_conv_2d_nhwc_hwc"; // conv_3d // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2)))) + if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2)) return "linalg.conv_3d"; Block *body = genericOp.getBlock(); @@ -423,8 +484,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) { + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_nchw_max"; if (bodyMatcherForSumPoolOps(yieldVal, body)) @@ -437,8 +498,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_nhwc_max"; @@ -453,13 +514,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) return "linalg.pooling_nhwc_max_unsigned"; if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_max_unsigned"; + return "linalg.pooling_nhwc_min_unsigned"; } return ""; } @@ -474,8 +535,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (indexingMaps.size() == 3 && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) return "linalg.conv_2d_nhwc_fhwc"; @@ -484,8 +545,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) && (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) return "linalg.conv_2d_nhwc_hwcf"; @@ -496,8 +557,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { if (indexingMaps.size() == 3 && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) return "linalg.conv_2d_nchw_fchw"; // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps) @@ -508,8 +569,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) return "linalg.conv_2d_nhwc_fhwc_q"; @@ -522,8 +583,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) return "linalg.conv_2d_nchw_fchw_q"; // depthwise_conv_2d_nhwc_hwcm @@ -532,8 +593,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> if (indexingMaps.size() == 3 && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.depthwise_conv_2d_nhwc_hwcm"; @@ -545,8 +606,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; @@ -564,8 +625,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2))) return "linalg.conv_2d_ngchw_fgchw"; // conv_2d_ngchw_gfchw @@ -576,8 +637,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) return "linalg.conv_2d_ngchw_gfchw"; // conv_2d_ngchw_gfchw_q @@ -590,8 +651,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) return "linalg.conv_2d_ngchw_gfchw_q"; // conv_2d_nhwgc_gfhwc @@ -599,8 +660,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4))) @@ -611,18 +672,18 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4)))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4)) return "linalg.depthwise_conv_3d_ncdhw_cdhw"; // depthwise_conv_3d_ndhwc_dhwc // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.depthwise_conv_3d_ndhwc_dhwc"; @@ -636,9 +697,9 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_ndhwc_max"; @@ -660,9 +721,9 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) return "linalg.conv_3d_ncdhw_fcdhw"; // conv_3d_ndhwc_dhwcf @@ -670,22 +731,22 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) && + (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) return "linalg.conv_3d_ndhwc_dhwcf"; // depthwise_conv_3d_ndhwc_dhwcm // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) && - (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) && - (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) && - (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5))) + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) && + (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5))) return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; return ""; } From aae7e048da5f9fa34e7eb1a0de7559e823e23866 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 8 Oct 2025 06:23:27 -0500 Subject: [PATCH 06/18] Concise v2.0 --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 4 + .../Dialect/Linalg/Transforms/Specialize.cpp | 178 ++++++++++-------- 2 files changed, 100 insertions(+), 82 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 48978eb7663d5..46f9f10789e6c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -114,6 +114,10 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); // Fusion / Tiling utilities //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Fusion / Tiling utilities +//===----------------------------------------------------------------------===// + /// The type of loops to be generated during tiling. enum class LinalgTilingLoopType { Loops = 0, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 01a5c3bebd146..a741cd126dd3b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -354,16 +354,18 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2)) return "linalg.depthwise_conv_1d_ncw_cw"; // depthwise_conv_1d_nwc_wc // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1)) return "linalg.depthwise_conv_1d_nwc_wc"; // conv_2d @@ -382,8 +384,8 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_ncw_max"; @@ -396,9 +398,9 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) { + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_nwc_max"; if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) @@ -417,28 +419,29 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) && - (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3))) + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)) return "linalg.depthwise_conv_1d_nwc_wcm"; // conv_1d_nwc_wcf // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) && - (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)) return "linalg.conv_1d_nwc_wcf"; // conv_1d_ncw_fcw // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) return "linalg.conv_1d_ncw_fcw"; return ""; } @@ -451,8 +454,9 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3)) return "linalg.depthwise_conv_2d_nchw_chw"; @@ -460,10 +464,11 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) return "linalg.depthwise_conv_2d_nhwc_hwc"; // conv_3d // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> @@ -482,8 +487,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) @@ -497,10 +502,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_nhwc_max"; if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) @@ -513,10 +518,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) { + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) { if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) return "linalg.pooling_nhwc_max_unsigned"; if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) @@ -534,32 +539,32 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (indexingMaps.size() == 3 && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)) return "linalg.conv_2d_nhwc_fhwc"; // conv_2d_nhwc_hwcf // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) && - (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)) return "linalg.conv_2d_nhwc_hwcf"; // conv_2d_nchw_fchw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (indexingMaps.size() == 3 && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) return "linalg.conv_2d_nchw_fchw"; // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps) // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> @@ -568,11 +573,11 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)) return "linalg.conv_2d_nhwc_fhwc_q"; // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps) // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> @@ -581,22 +586,23 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) return "linalg.conv_2d_nchw_fchw_q"; // depthwise_conv_2d_nhwc_hwcm // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> if (indexingMaps.size() == 3 && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)) return "linalg.depthwise_conv_2d_nhwc_hwcm"; // depthwise_conv_2d_nhwc_hwcm_q // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> @@ -605,11 +611,12 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)) return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; return ""; } @@ -622,24 +629,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2))) + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)) return "linalg.conv_2d_ngchw_fgchw"; // conv_2d_ngchw_gfchw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (indexingMaps.size() == 3 && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)) return "linalg.conv_2d_ngchw_gfchw"; // conv_2d_ngchw_gfchw_q // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> @@ -648,30 +657,33 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (indexingMaps.size() == 5 && (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && - (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))) + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)) return "linalg.conv_2d_ngchw_gfchw_q"; // conv_2d_nhwgc_gfhwc // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && - (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) && - (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)) return "linalg.conv_2d_nhwgc_gfhwc"; // depthwise_conv_3d_ncdhw_cdhw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4)) @@ -680,11 +692,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) return "linalg.depthwise_conv_3d_ndhwc_dhwc"; Block *body = genericOp.getBlock(); @@ -696,11 +709,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) { + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) { if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return "linalg.pooling_ndhwc_max"; if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) @@ -719,34 +732,35 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && - (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))) + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) return "linalg.conv_3d_ncdhw_fcdhw"; // conv_3d_ndhwc_dhwcf // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) && - (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)) return "linalg.conv_3d_ndhwc_dhwcf"; // depthwise_conv_3d_ndhwc_dhwcm // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> - if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) && + if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) && - (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5))) + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)) return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; return ""; } From bafdb41a8e539d77adcf6253b1631d22c69b2d45 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 8 Oct 2025 07:08:18 -0500 Subject: [PATCH 07/18] Start pulling out into separate APIs --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 9 +- .../Dialect/Linalg/Transforms/Specialize.cpp | 34 +-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 219 ++++++++++++++++++ 3 files changed, 233 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 46f9f10789e6c..222b66ca51708 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -111,9 +111,16 @@ std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// -// Fusion / Tiling utilities +// Convolution matcher utilities //===----------------------------------------------------------------------===// +bool isaConv1DOp(LinalgOp op); +bool isaConv1DNwcWcfOp(LinalgOp op); +bool isaConv1DNcwFcwOp(LinalgOp op); +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op); +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op); +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index a741cd126dd3b..8b51b8f13ce0d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -338,35 +338,24 @@ bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned a } static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() != 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; - if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0)) - return "linalg.conv_1d"; + if (isaConv1DOp(genericOp)) return "linalg.conv_1d"; return ""; } static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() != 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; // depthwise_conv_1d_ncw_cw // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2)) + if (isaDepthwiseConv1DNcwCwOp(genericOp)) return "linalg.depthwise_conv_1d_ncw_cw"; // depthwise_conv_1d_nwc_wc // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1)) + if (isaDepthwiseConv1DNwcWcOp(genericOp)) return "linalg.depthwise_conv_1d_nwc_wc"; // conv_2d // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> @@ -414,34 +403,23 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() != 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; // depthwise_conv_1d_nwc_wcm // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)) + if (isaDepthwiseConv1DNwcWcmOp(genericOp)) return "linalg.depthwise_conv_1d_nwc_wcm"; // conv_1d_nwc_wcf // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)) + if (isaConv1DNwcWcfOp(genericOp)) return "linalg.conv_1d_nwc_wcf"; // conv_1d_ncw_fcw // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) + if (isaConv1DNcwFcwOp(genericOp)) return "linalg.conv_1d_ncw_fcw"; return ""; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 3593b5348d268..12f88caf08fc7 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,6 +240,225 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +// ------------------------------- +// ---------- CONV --------------- +// ------------------------------- + +/// Utility to match block body for linalg.pool* ops. +template +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + // if (!defOp) return false; + if (!(isa_and_present(defOp) || ...)) return false; + + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg) return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, + uint32_t mapIndex, uint32_t dimIndex) { + auto affineMap = cast(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +// Check if `expr` is either: +// - a dimension expr alone (implying *1), or +// - a multiplication of dimension expr by constant. +static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) { + if (auto dExpr = dyn_cast(expr)) { + dim = dExpr; + constantValue = 1; + return true; + } + + auto mulExpr = dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + if (auto dExpr = dyn_cast(lhs)) { + if (auto cst = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + if (auto cst = dyn_cast(lhs)) { + if (auto dExpr = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + return false; +} + +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) { + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); + auto addExpr = dyn_cast(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + // TODO(Abhishek-Varma): Use this information in specialize.cpp. + int64_t c0, c1; + + if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && + isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { + // Pattern matched with dims and constants extracted. + AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); + return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr)); + } + return false; +} + +static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) { + return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim); +} + +static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef expectedSizes) { + if (indexingMaps.size() != expectedSizes.size()) return false; + + for (auto [indexingMap, expectedSize] : llvm::zip_equal(indexingMaps, expectedSizes)) { + auto affineMap = cast(indexingMap).getValue(); + if (affineMap.getNumResults() != expectedSize) return false; + } + return true; +} + +bool isaConv1DOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {1,1,1})) return false; + + // #map = affine_map<(d0, d1) -> (d0 + d1)> + // #map1 = affine_map<(d0, d1) -> (d1)> + // #map2 = affine_map<(d0, d1) -> (d0)> + return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0); +} + +bool isaConv1DNwcWcfOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)); +} + +bool isaConv1DNcwFcwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); +} + +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2)); +} + +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1)); +} + +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> + // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From a08247c4804f504535a061c77fe782e5b75991f4 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 8 Oct 2025 07:47:06 -0500 Subject: [PATCH 08/18] Some more APIs --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 14 + .../Dialect/Linalg/Transforms/Specialize.cpp | 105 +------ mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 277 ++++++++++++++++++ 3 files changed, 306 insertions(+), 90 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 222b66ca51708..b4955625b6dec 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -120,6 +120,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op); bool isaDepthwiseConv1DNcwCwOp(LinalgOp op); bool isaDepthwiseConv1DNwcWcOp(LinalgOp op); bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op); +bool isaConv2DOp(LinalgOp op); +bool isaConv2DNhwcFhwcOp(LinalgOp op); +bool isaConv2DNhwcHwcfOp(LinalgOp op); +bool isaConv2DNchwFchwOp(LinalgOp op); +bool isaConv2DNhwcFhwcQOp(LinalgOp op); +bool isaConv2DNchwFchwQOp(LinalgOp op); +bool isaConv2DNgchwFgchwOp(LinalgOp op); +bool isaConv2DNgchwGfchwOp(LinalgOp op); +bool isaConv2DNgchwGfchwQOp(LinalgOp op); +bool isaConv2DNhwgcGfhwcOp(LinalgOp op); +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op); +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op); +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op); +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 8b51b8f13ce0d..968370c05615a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -361,10 +361,10 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> - if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1)) + if (isaConv2DOp(genericOp)) return "linalg.conv_2d"; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; Block *body = genericOp.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); @@ -432,21 +432,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3)) + if (isaDepthwiseConv2DNchwChwOp(genericOp)) return "linalg.depthwise_conv_2d_nchw_chw"; // depthwise_conv_2d_nhwc_hwc // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) + if (isaDepthwiseConv2DNhwcHwcOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwc"; // conv_3d // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> @@ -511,90 +503,50 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; // conv_2d_nhwc_fhwc // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if (indexingMaps.size() == 3 && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)) + if (isaConv2DNhwcFhwcOp(genericOp)) return "linalg.conv_2d_nhwc_fhwc"; // conv_2d_nhwc_hwcf // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)) + if (isaConv2DNhwcHwcfOp(genericOp)) return "linalg.conv_2d_nhwc_hwcf"; // conv_2d_nchw_fchw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if (indexingMaps.size() == 3 && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) + if (isaConv2DNchwFchwOp(genericOp)) return "linalg.conv_2d_nchw_fchw"; // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps) // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if (indexingMaps.size() == 5 && - (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)) + if (isaConv2DNhwcFhwcQOp(genericOp)) return "linalg.conv_2d_nhwc_fhwc_q"; // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps) // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - if (indexingMaps.size() == 5 && - (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) + if (isaConv2DNchwFchwQOp(genericOp)) return "linalg.conv_2d_nchw_fchw_q"; // depthwise_conv_2d_nhwc_hwcm // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - if (indexingMaps.size() == 3 && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)) + if (isaDepthwiseConv2DNhwcHwcmOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwcm"; // depthwise_conv_2d_nhwc_hwcm_q // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - if (indexingMaps.size() == 5 && - (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)) + if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; return ""; } @@ -607,53 +559,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)) + if (isaConv2DNgchwFgchwOp(genericOp)) return "linalg.conv_2d_ngchw_fgchw"; // conv_2d_ngchw_gfchw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if (indexingMaps.size() == 3 && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)) + if (isaConv2DNgchwGfchwOp(genericOp)) return "linalg.conv_2d_ngchw_gfchw"; // conv_2d_ngchw_gfchw_q // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if (indexingMaps.size() == 5 && - (indexingMaps[2] == indexingMaps[3] && cast(indexingMaps[2]).getValue().getNumResults() == 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)) + if (isaConv2DNgchwGfchwQOp(genericOp)) return "linalg.conv_2d_ngchw_gfchw_q"; // conv_2d_nhwgc_gfhwc // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)) + if (isaConv2DNhwgcGfhwcOp(genericOp)) return "linalg.conv_2d_nhwgc_gfhwc"; // depthwise_conv_3d_ncdhw_cdhw // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 12f88caf08fc7..c5bb184c726f8 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -459,6 +459,283 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); } +bool isaConv2DOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false; + + // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1)); +} + +bool isaConv2DNhwcFhwcOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); +} + +bool isaConv2DNhwcHwcfOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); +} + +bool isaConv2DNchwFchwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); +} + +bool isaConv2DNhwcFhwcQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); +} + +bool isaConv2DNchwFchwQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); +} + +bool isaConv2DNgchwFgchwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)); +} + +bool isaConv2DNgchwGfchwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (indexingMaps.size() == 3 && + matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); +} + +bool isaConv2DNgchwGfchwQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); +} + +bool isaConv2DNhwgcGfhwcOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); +} + +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3)); +} + +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); +} + +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); +} + +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From 053d912f321b72cde08df10ef5cc9690248247cc Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 8 Oct 2025 07:59:55 -0500 Subject: [PATCH 09/18] Clean a bit --- .../Dialect/Linalg/Transforms/Specialize.cpp | 97 +------------------ 1 file changed, 5 insertions(+), 92 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 968370c05615a..ea94b49946545 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -343,27 +343,14 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { } static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() != 3) return ""; - // depthwise_conv_1d_ncw_cw - // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> - // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> if (isaDepthwiseConv1DNcwCwOp(genericOp)) return "linalg.depthwise_conv_1d_ncw_cw"; - // depthwise_conv_1d_nwc_wc - // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> if (isaDepthwiseConv1DNwcWcOp(genericOp)) return "linalg.depthwise_conv_1d_nwc_wc"; - // conv_2d - // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> - // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> if (isaConv2DOp(genericOp)) return "linalg.conv_2d"; + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; Block *body = genericOp.getBlock(); auto yieldOp = cast(body->getTerminator()); @@ -401,45 +388,24 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { } static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() != 3) return ""; - // depthwise_conv_1d_nwc_wcm - // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> - // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> - // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> if (isaDepthwiseConv1DNwcWcmOp(genericOp)) return "linalg.depthwise_conv_1d_nwc_wcm"; - // conv_1d_nwc_wcf - // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> - // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> - // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> if (isaConv1DNwcWcfOp(genericOp)) return "linalg.conv_1d_nwc_wcf"; - // conv_1d_ncw_fcw - // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> - // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> if (isaConv1DNcwFcwOp(genericOp)) return "linalg.conv_1d_ncw_fcw"; return ""; } static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - // depthwise_conv_2d_nchw_chw - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> if (isaDepthwiseConv2DNchwChwOp(genericOp)) return "linalg.depthwise_conv_2d_nchw_chw"; - // depthwise_conv_2d_nhwc_hwc - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> if (isaDepthwiseConv2DNhwcHwcOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwc"; + + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; // conv_3d // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> @@ -501,83 +467,30 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { } static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - // conv_2d_nhwc_fhwc - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (isaConv2DNhwcFhwcOp(genericOp)) return "linalg.conv_2d_nhwc_fhwc"; - // conv_2d_nhwc_hwcf - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (isaConv2DNhwcHwcfOp(genericOp)) return "linalg.conv_2d_nhwc_hwcf"; - // conv_2d_nchw_fchw - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (isaConv2DNchwFchwOp(genericOp)) return "linalg.conv_2d_nchw_fchw"; - // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps) - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (isaConv2DNhwcFhwcQOp(genericOp)) return "linalg.conv_2d_nhwc_fhwc_q"; - // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps) - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> if (isaConv2DNchwFchwQOp(genericOp)) return "linalg.conv_2d_nchw_fchw_q"; - // depthwise_conv_2d_nhwc_hwcm - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> if (isaDepthwiseConv2DNhwcHwcmOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwcm"; - // depthwise_conv_2d_nhwc_hwcm_q - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; return ""; } static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - // conv_2d_ngchw_fgchw - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (isaConv2DNgchwFgchwOp(genericOp)) return "linalg.conv_2d_ngchw_fgchw"; - // conv_2d_ngchw_gfchw - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (isaConv2DNgchwGfchwOp(genericOp)) return "linalg.conv_2d_ngchw_gfchw"; - // conv_2d_ngchw_gfchw_q - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (isaConv2DNgchwGfchwQOp(genericOp)) return "linalg.conv_2d_ngchw_gfchw_q"; - // conv_2d_nhwgc_gfhwc - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> if (isaConv2DNhwgcGfhwcOp(genericOp)) return "linalg.conv_2d_nhwgc_gfhwc"; // depthwise_conv_3d_ncdhw_cdhw From 87b91ee5602ff9b031f98a980f6283d517190092 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 9 Oct 2025 02:36:54 -0500 Subject: [PATCH 10/18] Add 3D APIs --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 6 + .../Dialect/Linalg/Transforms/Specialize.cpp | 70 ++--------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 117 ++++++++++++++++++ 3 files changed, 132 insertions(+), 61 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index b4955625b6dec..ad5e0818b90f5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -134,6 +134,12 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op); +bool isaConv3DOp(LinalgOp op); +bool isaConv3DNcdhwFcdhwOp(LinalgOp op); +bool isaConv3DNdhwcDhwcfOp(LinalgOp op); +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op); +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op); +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index ea94b49946545..6ecc6a024bed8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -406,13 +406,7 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); if (indexingMaps.size() < 3) return ""; unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - // conv_3d - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> - if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2)) + if (isaConv3DOp(genericOp)) return "linalg.conv_3d"; Block *body = genericOp.getBlock(); @@ -493,29 +487,14 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { return "linalg.conv_2d_ngchw_gfchw_q"; if (isaConv2DNhwgcGfhwcOp(genericOp)) return "linalg.conv_2d_nhwgc_gfhwc"; - // depthwise_conv_3d_ncdhw_cdhw - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4)) + if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp)) return "linalg.depthwise_conv_3d_ncdhw_cdhw"; - // depthwise_conv_3d_ndhwc_dhwc - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) + if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp)) return "linalg.depthwise_conv_3d_ndhwc_dhwc"; + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + if (indexingMaps.size() < 3) return ""; + unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; Block *body = genericOp.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); @@ -541,42 +520,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { } static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - // conv_3d_ncdhw_fcdhw - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)) + if (isaConv3DNcdhwFcdhwOp(genericOp)) return "linalg.conv_3d_ncdhw_fcdhw"; - // conv_3d_ndhwc_dhwcf - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)) + if (isaConv3DNdhwcDhwcfOp(genericOp)) return "linalg.conv_3d_ndhwc_dhwcf"; - // depthwise_conv_3d_ndhwc_dhwcm - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)) + if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp)) return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; return ""; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c5bb184c726f8..b3e79e8c1a409 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -736,6 +736,123 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); } +bool isaConv3DOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; + + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2)); +} + +bool isaConv3DNcdhwFcdhwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); +} + +bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); +} + +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); +} + +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4)); +} + +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From e5acca4e6d5c76cfd00688d487c16f6082ecd3a7 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 9 Oct 2025 02:43:28 -0500 Subject: [PATCH 11/18] Fix the NamedOp versions --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 40 ++++++++++++------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index b3e79e8c1a409..2d6d51d858853 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -460,7 +460,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) { } bool isaConv2DOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -475,7 +475,7 @@ bool isaConv2DOp(LinalgOp op) { } bool isaConv2DNhwcFhwcOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -494,7 +494,7 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) { } bool isaConv2DNhwcHwcfOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -513,7 +513,7 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) { } bool isaConv2DNchwFchwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -532,7 +532,7 @@ bool isaConv2DNchwFchwOp(LinalgOp op) { } bool isaConv2DNhwcFhwcQOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -552,7 +552,7 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) { } bool isaConv2DNchwFchwQOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -572,7 +572,7 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) { } bool isaConv2DNgchwFgchwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -593,7 +593,7 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) { } bool isaConv2DNgchwGfchwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -615,7 +615,7 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) { } bool isaConv2DNgchwGfchwQOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -637,7 +637,7 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) { } bool isaConv2DNhwgcGfhwcOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -658,7 +658,7 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) { } bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -677,7 +677,7 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) { } bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -696,7 +696,7 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) { } bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -716,7 +716,7 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) { } bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -737,7 +737,7 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { } bool isaConv3DOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -753,7 +753,7 @@ bool isaConv3DOp(LinalgOp op) { } bool isaConv3DNcdhwFcdhwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -773,7 +773,7 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) { } bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -793,7 +793,7 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { } bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -814,7 +814,7 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { } bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -834,7 +834,7 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) { } bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; From 535f7e95288274f104b5c6fda01012016613d3b3 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 9 Oct 2025 03:22:39 -0500 Subject: [PATCH 12/18] Pooling ops' --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 15 + .../Dialect/Linalg/Transforms/Specialize.cpp | 243 ++----------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 328 ++++++++++++++++++ 3 files changed, 373 insertions(+), 213 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index ad5e0818b90f5..1a1b70d3eb979 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -140,6 +140,21 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op); bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op); bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op); bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op); +bool isaPoolingNchwMaxOp(LinalgOp op); +bool isaPoolingNchwSumOp(LinalgOp op); +bool isaPoolingNhwcMaxOp(LinalgOp op); +bool isaPoolingNhwcMinOp(LinalgOp op); +bool isaPoolingNhwcSumOp(LinalgOp op); +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op); +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op); +bool isaPoolingNcwMaxOp(LinalgOp op); +bool isaPoolingNcwSumOp(LinalgOp op); +bool isaPoolingNwcMaxOp(LinalgOp op); +bool isaPoolingNwcMinOp(LinalgOp op); +bool isaPoolingNwcSumOp(LinalgOp op); +bool isaPoolingNdhwcMaxOp(LinalgOp op); +bool isaPoolingNdhwcMinOp(LinalgOp op); +bool isaPoolingNdhwcSumOp(LinalgOp op); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 6ecc6a024bed8..aef3a1480d289 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,106 +237,6 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } -/// Utility to match block body for linalg.pool* ops. -template -static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { - Operation *defOp = yieldVal.getDefiningOp(); - // if (!defOp) return false; - if (!(isa_and_present(defOp) || ...)) return false; - - BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); - BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); - if (!lhsArg || !rhsArg) return false; - return true; -} - -static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); -} - -static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); -} - -static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); -} - -static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); -} - -static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); -} - -static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, - uint32_t mapIndex, uint32_t dimIndex) { - auto affineMap = cast(indexingMaps[mapIndex]).getValue(); - if (dimIndex < affineMap.getNumResults()) - return affineMap.getResult(dimIndex); - return nullptr; -} - -// Check if `expr` is either: -// - a dimension expr alone (implying *1), or -// - a multiplication of dimension expr by constant. -bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) { - if (auto dExpr = dyn_cast(expr)) { - dim = dExpr; - constantValue = 1; - return true; - } - - auto mulExpr = dyn_cast(expr); - if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) - return false; - - AffineExpr lhs = mulExpr.getLHS(); - AffineExpr rhs = mulExpr.getRHS(); - - if (auto dExpr = dyn_cast(lhs)) { - if (auto cst = dyn_cast(rhs)) { - dim = dExpr; - constantValue = cst.getValue(); - return true; - } - } - if (auto cst = dyn_cast(lhs)) { - if (auto dExpr = dyn_cast(rhs)) { - dim = dExpr; - constantValue = cst.getValue(); - return true; - } - } - return false; -} - -bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) { - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); - auto addExpr = dyn_cast(inpExpr); - if (!addExpr || addExpr.getKind() != AffineExprKind::Add) - return false; - - AffineExpr dim0, dim1; - // TODO(Abhishek-Varma): Use this information in specialize.cpp. - int64_t c0, c1; - - if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && - isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { - // Pattern matched with dims and constants extracted. - AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); - AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); - return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr)); - } - return false; -} - -bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) { - return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim); -} - static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { if (isaConv1DOp(genericOp)) return "linalg.conv_1d"; return ""; @@ -349,41 +249,16 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { return "linalg.depthwise_conv_1d_nwc_wc"; if (isaConv2DOp(genericOp)) return "linalg.conv_2d"; - - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - Block *body = genericOp.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - // pooling_ncw_max - // pooling_ncw_sum - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) { - if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) - return "linalg.pooling_ncw_max"; - if (bodyMatcherForSumPoolOps(yieldVal, body)) - return "linalg.pooling_ncw_sum"; - } - // pooling_nwc_max - // pooling_nwc_min - // pooling_nwc_sum - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) { - if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) - return "linalg.pooling_nwc_max"; - if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) - return "linalg.pooling_nwc_min"; - if (bodyMatcherForSumPoolOps(yieldVal, body)) - return "linalg.pooling_nwc_sum"; - } + if (isaPoolingNcwMaxOp(genericOp)) + return "linalg.pooling_ncw_max"; + if (isaPoolingNcwSumOp(genericOp)) + return "linalg.pooling_ncw_sum"; + if (isaPoolingNwcMaxOp(genericOp)) + return "linalg.pooling_nwc_max"; + if (isaPoolingNwcMinOp(genericOp)) + return "linalg.pooling_nwc_min"; + if (isaPoolingNwcSumOp(genericOp)) + return "linalg.pooling_nwc_sum"; return ""; } @@ -402,61 +277,22 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { return "linalg.depthwise_conv_2d_nchw_chw"; if (isaDepthwiseConv2DNhwcHwcOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwc"; - - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; if (isaConv3DOp(genericOp)) return "linalg.conv_3d"; - - Block *body = genericOp.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - // pooling_nchw_max - // pooling_nchw_sum - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) { - if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) - return "linalg.pooling_nchw_max"; - if (bodyMatcherForSumPoolOps(yieldVal, body)) - return "linalg.pooling_nchw_sum"; - } - // pooling_nhwc_max - // pooling_nhwc_min - // pooling_nhwc_sum - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) { - if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_max"; - if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_min"; - if (bodyMatcherForSumPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_sum"; - } - // pooling_nhwc_max_unsigned - // pooling_nhwc_min_unsigned - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) { - if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_max_unsigned"; - if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) - return "linalg.pooling_nhwc_min_unsigned"; - } + if (isaPoolingNchwMaxOp(genericOp)) + return "linalg.pooling_nchw_max"; + if (isaPoolingNchwSumOp(genericOp)) + return "linalg.pooling_nchw_sum"; + if (isaPoolingNhwcMaxOp(genericOp)) + return "linalg.pooling_nhwc_max"; + if (isaPoolingNhwcMinOp(genericOp)) + return "linalg.pooling_nhwc_min"; + if (isaPoolingNhwcSumOp(genericOp)) + return "linalg.pooling_nhwc_sum"; + if (isaPoolingNhwcMaxUnsignedOp(genericOp)) + return "linalg.pooling_nhwc_max_unsigned"; + if (isaPoolingNhwcMinUnsignedOp(genericOp)) + return "linalg.pooling_nhwc_min_unsigned"; return ""; } @@ -491,31 +327,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { return "linalg.depthwise_conv_3d_ncdhw_cdhw"; if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp)) return "linalg.depthwise_conv_3d_ndhwc_dhwc"; - - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - if (indexingMaps.size() < 3) return ""; - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - Block *body = genericOp.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - // pooling_ndhwc_max - // pooling_ndhwc_min - // pooling_ndhwc_sum - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) { - if (bodyMatcherForMaxSignedPoolOps(yieldVal, body)) - return "linalg.pooling_ndhwc_max"; - if (bodyMatcherForMinSignedPoolOps(yieldVal, body)) - return "linalg.pooling_ndhwc_min"; - if (bodyMatcherForSumPoolOps(yieldVal, body)) - return "linalg.pooling_ndhwc_sum"; - } + if (isaPoolingNdhwcMaxOp(genericOp)) + return "linalg.pooling_ndhwc_max"; + if (isaPoolingNdhwcMinOp(genericOp)) + return "linalg.pooling_ndhwc_min"; + if (isaPoolingNdhwcSumOp(genericOp)) + return "linalg.pooling_ndhwc_sum"; return ""; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 2d6d51d858853..127e61d7db050 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -853,6 +853,334 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); } +bool isaPoolingNchwMaxOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNchwSumOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) && + bodyMatcherForSumPoolOps(yieldVal, body)); +} + +bool isaPoolingNhwcMaxOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNhwcMinOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNhwcSumOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); +} + +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNcwMaxOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNcwSumOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + bodyMatcherForSumPoolOps(yieldVal, body)); +} + +bool isaPoolingNwcMaxOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNwcMinOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNwcSumOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForSumPoolOps(yieldVal, body)); +} + +bool isaPoolingNdhwcMaxOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNdhwcMinOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); +} + +bool isaPoolingNdhwcSumOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForSumPoolOps(yieldVal, body)); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From b06ba750c040a5b2506271ccf96d60fceb02cdef Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 9 Oct 2025 03:30:56 -0500 Subject: [PATCH 13/18] Updated maps --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 66 ++++++++++++------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 127e61d7db050..e847f2cf3aef2 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -909,9 +909,9 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && @@ -931,9 +931,9 @@ bool isaPoolingNhwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && @@ -953,9 +953,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && @@ -1019,9 +1019,9 @@ bool isaPoolingNcwMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && @@ -1040,9 +1040,9 @@ bool isaPoolingNcwSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && @@ -1061,9 +1061,9 @@ bool isaPoolingNwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && @@ -1082,9 +1082,9 @@ bool isaPoolingNwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && @@ -1103,9 +1103,9 @@ bool isaPoolingNwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)> - // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && @@ -1124,9 +1124,9 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && @@ -1147,9 +1147,9 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && @@ -1170,9 +1170,9 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && From 5aeb3716ff7ddb376bb4a9db962431dba1b55b56 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 9 Oct 2025 04:42:10 -0500 Subject: [PATCH 14/18] Missing ops --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 4 + .../Dialect/Linalg/Transforms/Specialize.cpp | 16 ++++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 83 +++++++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 1a1b70d3eb979..2f7868bd55182 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -128,15 +128,19 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op); bool isaConv2DNchwFchwQOp(LinalgOp op); bool isaConv2DNgchwFgchwOp(LinalgOp op); bool isaConv2DNgchwGfchwOp(LinalgOp op); +bool isaConv2DNhwcHwcfQOp(LinalgOp op); +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op); bool isaConv2DNgchwGfchwQOp(LinalgOp op); bool isaConv2DNhwgcGfhwcOp(LinalgOp op); bool isaDepthwiseConv2DNchwChwOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op); +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op); bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op); bool isaConv3DOp(LinalgOp op); bool isaConv3DNcdhwFcdhwOp(LinalgOp op); bool isaConv3DNdhwcDhwcfOp(LinalgOp op); +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op); bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op); bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op); bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index aef3a1480d289..031cb3b919b96 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -277,6 +277,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { return "linalg.depthwise_conv_2d_nchw_chw"; if (isaDepthwiseConv2DNhwcHwcOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwc"; + if (isaDepthwiseConv2DNhwcHwcQOp(genericOp)) + return "linalg.depthwise_conv_2d_nhwc_hwc_q"; if (isaConv3DOp(genericOp)) return "linalg.conv_3d"; if (isaPoolingNchwMaxOp(genericOp)) @@ -307,6 +309,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { return "linalg.conv_2d_nhwc_fhwc_q"; if (isaConv2DNchwFchwQOp(genericOp)) return "linalg.conv_2d_nchw_fchw_q"; + if (isaConv2DNhwcHwcfQOp(genericOp)) + return "linalg.conv_2d_nhwc_hwcf_q"; if (isaDepthwiseConv2DNhwcHwcmOp(genericOp)) return "linalg.depthwise_conv_2d_nhwc_hwcm"; if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp)) @@ -323,6 +327,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { return "linalg.conv_2d_ngchw_gfchw_q"; if (isaConv2DNhwgcGfhwcOp(genericOp)) return "linalg.conv_2d_nhwgc_gfhwc"; + if (isaConv2DNhwgcGfhwcQOp(genericOp)) + return "linalg.conv_2d_nhwgc_gfhwc_q"; if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp)) return "linalg.depthwise_conv_3d_ncdhw_cdhw"; if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp)) @@ -341,6 +347,8 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { return "linalg.conv_3d_ncdhw_fcdhw"; if (isaConv3DNdhwcDhwcfOp(genericOp)) return "linalg.conv_3d_ndhwc_dhwcf"; + if (isaConv3DNdhwcDhwcfQOp(genericOp)) + return "linalg.conv_3d_ndhwc_dhwcf_q"; if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp)) return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; return ""; @@ -412,6 +420,10 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") { @@ -420,12 +432,16 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_3d") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") { namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index e847f2cf3aef2..b239f62a7049d 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -614,6 +614,48 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); } +bool isaConv2DNhwcHwcfQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); +} + +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4) + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); +} + bool isaConv2DNgchwGfchwQOp(LinalgOp op) { if (isa(op)) return true; @@ -736,6 +778,26 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); } +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); +} + bool isaConv3DOp(LinalgOp op) { if (isa(op)) return true; @@ -792,6 +854,27 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); } +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) { + if (isa(op)) return true; + + if (!isaConvolutionOpInterface(op)) return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 4; + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> + return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); +} + bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { if (isa(op)) return true; From f1b8e80ff65c4db082b634eadc3976c35dc4ccac Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 14 Oct 2025 04:37:28 -0500 Subject: [PATCH 15/18] Make use of dilations/strides info --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 84 +-- .../Dialect/Linalg/Transforms/Specialize.cpp | 355 +++++------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 546 ++++++++++++------ 3 files changed, 548 insertions(+), 437 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 2f7868bd55182..44ebc101d7c37 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -115,50 +115,50 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// bool isaConv1DOp(LinalgOp op); -bool isaConv1DNwcWcfOp(LinalgOp op); -bool isaConv1DNcwFcwOp(LinalgOp op); -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op); -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op); -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op); +bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); bool isaConv2DOp(LinalgOp op); -bool isaConv2DNhwcFhwcOp(LinalgOp op); -bool isaConv2DNhwcHwcfOp(LinalgOp op); -bool isaConv2DNchwFchwOp(LinalgOp op); -bool isaConv2DNhwcFhwcQOp(LinalgOp op); -bool isaConv2DNchwFchwQOp(LinalgOp op); -bool isaConv2DNgchwFgchwOp(LinalgOp op); -bool isaConv2DNgchwGfchwOp(LinalgOp op); -bool isaConv2DNhwcHwcfQOp(LinalgOp op); -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op); -bool isaConv2DNgchwGfchwQOp(LinalgOp op); -bool isaConv2DNhwgcGfhwcOp(LinalgOp op); -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op); -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op); -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op); -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op); -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op); +bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); bool isaConv3DOp(LinalgOp op); -bool isaConv3DNcdhwFcdhwOp(LinalgOp op); -bool isaConv3DNdhwcDhwcfOp(LinalgOp op); -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op); -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op); -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op); -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op); -bool isaPoolingNchwMaxOp(LinalgOp op); -bool isaPoolingNchwSumOp(LinalgOp op); -bool isaPoolingNhwcMaxOp(LinalgOp op); -bool isaPoolingNhwcMinOp(LinalgOp op); -bool isaPoolingNhwcSumOp(LinalgOp op); -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op); -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op); -bool isaPoolingNcwMaxOp(LinalgOp op); -bool isaPoolingNcwSumOp(LinalgOp op); -bool isaPoolingNwcMaxOp(LinalgOp op); -bool isaPoolingNwcMinOp(LinalgOp op); -bool isaPoolingNwcSumOp(LinalgOp op); -bool isaPoolingNdhwcMaxOp(LinalgOp op); -bool isaPoolingNdhwcMinOp(LinalgOp op); -bool isaPoolingNdhwcSumOp(LinalgOp op); +bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNchwSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNcwSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 031cb3b919b96..94dfbcc15d055 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,250 +237,169 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } -static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) { - if (isaConv1DOp(genericOp)) return "linalg.conv_1d"; - return ""; +template +static FailureOr specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef dilations, ArrayRef strides) { + SmallVector inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; } -static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) { - if (isaDepthwiseConv1DNcwCwOp(genericOp)) - return "linalg.depthwise_conv_1d_ncw_cw"; - if (isaDepthwiseConv1DNwcWcOp(genericOp)) - return "linalg.depthwise_conv_1d_nwc_wc"; +static FailureOr inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConv1DOp(genericOp)) return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); +} + +static FailureOr inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); if (isaConv2DOp(genericOp)) - return "linalg.conv_2d"; - if (isaPoolingNcwMaxOp(genericOp)) - return "linalg.pooling_ncw_max"; - if (isaPoolingNcwSumOp(genericOp)) - return "linalg.pooling_ncw_sum"; - if (isaPoolingNwcMaxOp(genericOp)) - return "linalg.pooling_nwc_max"; - if (isaPoolingNwcMinOp(genericOp)) - return "linalg.pooling_nwc_min"; - if (isaPoolingNwcSumOp(genericOp)) - return "linalg.pooling_nwc_sum"; - return ""; + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNcwSumOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNwcMinOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNwcSumOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) { - if (isaDepthwiseConv1DNwcWcmOp(genericOp)) - return "linalg.depthwise_conv_1d_nwc_wcm"; - if (isaConv1DNwcWcfOp(genericOp)) - return "linalg.conv_1d_nwc_wcf"; - if (isaConv1DNcwFcwOp(genericOp)) - return "linalg.conv_1d_ncw_fcw"; - return ""; +static FailureOr inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) { - if (isaDepthwiseConv2DNchwChwOp(genericOp)) - return "linalg.depthwise_conv_2d_nchw_chw"; - if (isaDepthwiseConv2DNhwcHwcOp(genericOp)) - return "linalg.depthwise_conv_2d_nhwc_hwc"; - if (isaDepthwiseConv2DNhwcHwcQOp(genericOp)) - return "linalg.depthwise_conv_2d_nhwc_hwc_q"; +static FailureOr inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); if (isaConv3DOp(genericOp)) - return "linalg.conv_3d"; - if (isaPoolingNchwMaxOp(genericOp)) - return "linalg.pooling_nchw_max"; - if (isaPoolingNchwSumOp(genericOp)) - return "linalg.pooling_nchw_sum"; - if (isaPoolingNhwcMaxOp(genericOp)) - return "linalg.pooling_nhwc_max"; - if (isaPoolingNhwcMinOp(genericOp)) - return "linalg.pooling_nhwc_min"; - if (isaPoolingNhwcSumOp(genericOp)) - return "linalg.pooling_nhwc_sum"; - if (isaPoolingNhwcMaxUnsignedOp(genericOp)) - return "linalg.pooling_nhwc_max_unsigned"; - if (isaPoolingNhwcMinUnsignedOp(genericOp)) - return "linalg.pooling_nhwc_min_unsigned"; - return ""; + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNchwSumOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) { - if (isaConv2DNhwcFhwcOp(genericOp)) - return "linalg.conv_2d_nhwc_fhwc"; - if (isaConv2DNhwcHwcfOp(genericOp)) - return "linalg.conv_2d_nhwc_hwcf"; - if (isaConv2DNchwFchwOp(genericOp)) - return "linalg.conv_2d_nchw_fchw"; - if (isaConv2DNhwcFhwcQOp(genericOp)) - return "linalg.conv_2d_nhwc_fhwc_q"; - if (isaConv2DNchwFchwQOp(genericOp)) - return "linalg.conv_2d_nchw_fchw_q"; - if (isaConv2DNhwcHwcfQOp(genericOp)) - return "linalg.conv_2d_nhwc_hwcf_q"; - if (isaDepthwiseConv2DNhwcHwcmOp(genericOp)) - return "linalg.depthwise_conv_2d_nhwc_hwcm"; - if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp)) - return "linalg.depthwise_conv_2d_nhwc_hwcm_q"; - return ""; +static FailureOr inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) { - if (isaConv2DNgchwFgchwOp(genericOp)) - return "linalg.conv_2d_ngchw_fgchw"; - if (isaConv2DNgchwGfchwOp(genericOp)) - return "linalg.conv_2d_ngchw_gfchw"; - if (isaConv2DNgchwGfchwQOp(genericOp)) - return "linalg.conv_2d_ngchw_gfchw_q"; - if (isaConv2DNhwgcGfhwcOp(genericOp)) - return "linalg.conv_2d_nhwgc_gfhwc"; - if (isaConv2DNhwgcGfhwcQOp(genericOp)) - return "linalg.conv_2d_nhwgc_gfhwc_q"; - if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp)) - return "linalg.depthwise_conv_3d_ncdhw_cdhw"; - if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp)) - return "linalg.depthwise_conv_3d_ndhwc_dhwc"; - if (isaPoolingNdhwcMaxOp(genericOp)) - return "linalg.pooling_ndhwc_max"; - if (isaPoolingNdhwcMinOp(genericOp)) - return "linalg.pooling_ndhwc_min"; - if (isaPoolingNdhwcSumOp(genericOp)) - return "linalg.pooling_ndhwc_sum"; - return ""; +static FailureOr inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) { - if (isaConv3DNcdhwFcdhwOp(genericOp)) - return "linalg.conv_3d_ncdhw_fcdhw"; - if (isaConv3DNdhwcDhwcfOp(genericOp)) - return "linalg.conv_3d_ndhwc_dhwcf"; - if (isaConv3DNdhwcDhwcfQOp(genericOp)) - return "linalg.conv_3d_ndhwc_dhwcf_q"; - if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp)) - return "linalg.depthwise_conv_3d_ndhwc_dhwcm"; - return ""; +static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector dilations, strides; + if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides)) + return specializeToConvOp(rewriter, genericOp, dilations, strides); + return failure(); } -static std::string inferConvolutionKind(GenericOp genericOp) { +// Converts linalg.generic to named linalg.*conv* where possible. +static FailureOr inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); unsigned totalIterators = iteratorTypes.size(); switch(totalIterators) { case 2: - return inferBasedOnRank2ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); case 4: - return inferBasedOnRank4ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); case 5: - return inferBasedOnRank5ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); case 6: - return inferBasedOnRank6ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); case 7: - return inferBasedOnRank7ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); case 8: - return inferBasedOnRank8ConvIteratorTypes(genericOp); + return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); case 9: - return inferBasedOnRank9ConvIteratorTypes(genericOp); - } - return ""; -} - -// Converts linalg.generic to named linalg.*conv* where possible. -static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, - GenericOp genericOp) { - std::string convKind = inferConvolutionKind(genericOp); - if (convKind == "") return failure(); - SmallVector inputs = genericOp.getDpsInputs(); - ValueRange outputs = genericOp.getDpsInits(); - SmallVector indexingMaps = genericOp.getIndexingMapsArray(); - SmallVector resultTypes = genericOp.hasPureTensorSemantics() - ? TypeRange(ValueRange(outputs)) - : TypeRange{}; - LinalgOp namedOp; - if (convKind == "linalg.conv_1d") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_1d_nwc_wcf") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_1d_ncw_fcw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwc_fhwc") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwc_hwcf") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nchw_fchw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nchw_fchw_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_ngchw_fgchw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_ngchw_gfchw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_3d") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nchw_max") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nchw_sum") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nhwc_max") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nhwc_min") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nhwc_sum") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nhwc_max_unsigned") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nhwc_min_unsigned") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_ncw_max") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_ncw_sum") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nwc_max") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nwc_min") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_nwc_sum") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_ndhwc_max") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_ndhwc_min") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); - } else if (convKind == "linalg.pooling_ndhwc_sum") { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); } - return namedOp; - return failure(); } @@ -566,7 +485,7 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, // Convolution - e.g. *conv* if (isaConvolutionOpInterface(genericOp)) { - return specializeLinalgConvolutions(rewriter, genericOp); + return inferAndSpecializeToConvolutionOp(rewriter, genericOp); } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index b239f62a7049d..548f43f83b0ed 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -319,7 +319,8 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_ return false; } -static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) { +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, + int64_t& dilation, int64_t& stride) { unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); auto addExpr = dyn_cast(inpExpr); @@ -327,7 +328,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un return false; AffineExpr dim0, dim1; - // TODO(Abhishek-Varma): Use this information in specialize.cpp. int64_t c0, c1; if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && @@ -335,7 +335,15 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un // Pattern matched with dims and constants extracted. AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); - return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr)); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } else if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } } return false; } @@ -354,6 +362,16 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef return true; } +static bool updateConvDilationsAndStrides(SmallVector* dilations, SmallVector* strides, ArrayRef tempDilations, ArrayRef tempStrides) { + if (!(dilations && strides)) + return true; + for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { + dilations->push_back(dilation); + strides->push_back(stride); + } + return true; +} + bool isaConv1DOp(LinalgOp op) { if (isa(op)) return true; @@ -365,10 +383,12 @@ bool isaConv1DOp(LinalgOp op) { // #map = affine_map<(d0, d1) -> (d0 + d1)> // #map1 = affine_map<(d0, d1) -> (d1)> // #map2 = affine_map<(d0, d1) -> (d0)> - return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0); + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); + return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]); } -bool isaConv1DNwcWcfOp(LinalgOp op) { +bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -377,16 +397,20 @@ bool isaConv1DNwcWcfOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv1DNcwFcwOp(LinalgOp op) { +bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -395,16 +419,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) { +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -413,16 +441,21 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2)); + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) { +// ------------------- +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -431,16 +464,20 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1)); + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) { +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -449,14 +486,18 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } bool isaConv2DOp(LinalgOp op) { @@ -467,14 +508,16 @@ bool isaConv2DOp(LinalgOp op) { ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false; + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> - return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1)); + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1])); } -bool isaConv2DNhwcFhwcOp(LinalgOp op) { +bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -483,17 +526,21 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNhwcHwcfOp(LinalgOp op) { +bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -502,17 +549,21 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNchwFchwOp(LinalgOp op) { +bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -521,17 +572,21 @@ bool isaConv2DNchwFchwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNhwcFhwcQOp(LinalgOp op) { +bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -540,18 +595,22 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNchwFchwQOp(LinalgOp op) { +bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -560,18 +619,22 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNgchwFgchwOp(LinalgOp op) { +bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -580,19 +643,23 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNgchwGfchwOp(LinalgOp op) { +bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -601,20 +668,23 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (indexingMaps.size() == 3 && - matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNhwcHwcfQOp(LinalgOp op) { +bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -623,18 +693,22 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) { +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -643,20 +717,24 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4) - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNgchwGfchwQOp(LinalgOp op) { +bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -665,20 +743,24 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv2DNhwgcGfhwcOp(LinalgOp op) { +bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -687,19 +769,23 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) { +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -708,17 +794,21 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3)); + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) { +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -727,17 +817,21 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) { +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -746,18 +840,22 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -766,19 +864,23 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) { +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -787,15 +889,19 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } bool isaConv3DOp(LinalgOp op) { @@ -806,15 +912,17 @@ bool isaConv3DOp(LinalgOp op) { ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> - return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2)); + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2])); } -bool isaConv3DNcdhwFcdhwOp(LinalgOp op) { +bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -823,18 +931,22 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { +bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -843,18 +955,22 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) { +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -863,19 +979,23 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 4; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -884,19 +1004,23 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) { +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -905,18 +1029,22 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4)); + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) { +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -925,18 +1053,22 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) { if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNchwMaxOp(LinalgOp op) { +bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -948,17 +1080,21 @@ bool isaPoolingNchwMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNchwSumOp(LinalgOp op) { +bool isaPoolingNchwSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -970,17 +1106,21 @@ bool isaPoolingNchwSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) && bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNhwcMaxOp(LinalgOp op) { +bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -992,17 +1132,21 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNhwcMinOp(LinalgOp op) { +bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1014,17 +1158,21 @@ bool isaPoolingNhwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNhwcSumOp(LinalgOp op) { +bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1036,17 +1184,21 @@ bool isaPoolingNhwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) { +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1058,17 +1210,21 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) { +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1080,17 +1236,21 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2,1); + SmallVector tempStrides(2,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNcwMaxOp(LinalgOp op) { +bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1102,16 +1262,20 @@ bool isaPoolingNcwMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNcwSumOp(LinalgOp op) { +bool isaPoolingNcwSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1123,16 +1287,20 @@ bool isaPoolingNcwSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNwcMaxOp(LinalgOp op) { +bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1144,16 +1312,20 @@ bool isaPoolingNwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNwcMinOp(LinalgOp op) { +bool isaPoolingNwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1165,16 +1337,20 @@ bool isaPoolingNwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNwcSumOp(LinalgOp op) { +bool isaPoolingNwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1186,16 +1362,20 @@ bool isaPoolingNwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(1,1); + SmallVector tempStrides(1,1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNdhwcMaxOp(LinalgOp op) { +bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1207,18 +1387,22 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNdhwcMinOp(LinalgOp op) { +bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1230,18 +1414,22 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } -bool isaPoolingNdhwcSumOp(LinalgOp op) { +bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { if (isa(op)) return true; if (!isaConvolutionOpInterface(op)) return false; @@ -1253,15 +1441,19 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) { auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(3,1); + SmallVector tempStrides(3,1); // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) && + bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); } Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, From 1a9417d3623074ba40ee95c7ef94c5c14b678d87 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 14 Oct 2025 06:19:23 -0500 Subject: [PATCH 16/18] Add lit test and clean up --- .../Dialect/Linalg/Transforms/Specialize.cpp | 5 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 9 + ...oundtrip-linalg-convolution-named-ops.mlir | 615 ++++++++++++++++++ 3 files changed, 627 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 94dfbcc15d055..12eb17ef0a435 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,7 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` with `dilations` and `strides`. template static FailureOr specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef dilations, ArrayRef strides) { SmallVector inputs = genericOp.getDpsInputs(); @@ -380,7 +381,7 @@ static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(Rewri return failure(); } -// Converts linalg.generic to named linalg.*conv* where possible. +// Converts linalg.generic to named linalg.*conv/pooling* where possible. To improve the search speed, the convolution ops have been segregated based on the rank of iterator types array. static FailureOr inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); unsigned totalIterators = iteratorTypes.size(); @@ -483,7 +484,7 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, return specializeLinalgContractions(rewriter, genericOp); } - // Convolution - e.g. *conv* + // Convolution - e.g. *conv/pooling* if (isaConvolutionOpInterface(genericOp)) { return inferAndSpecializeToConvolutionOp(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 548f43f83b0ed..0d4e8aa5e6382 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -319,6 +319,11 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_ return false; } +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * + +/// indexingMaps[n-1].getResult(oDim) * +/// where, CST_1 and CST_2 can be any constant. static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t& dilation, int64_t& stride) { unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; @@ -348,10 +353,13 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un return false; } +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim) static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) { return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim); } +/// Give an array of AffineMaps, verify each map to be of the corresponding `expectedSize`. static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef expectedSizes) { if (indexingMaps.size() != expectedSizes.size()) return false; @@ -362,6 +370,7 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef return true; } +/// Utility to update `dilations` and `strides` by copy the corresponding data from `tempDilations` and `tempStrides`. static bool updateConvDilationsAndStrides(SmallVector* dilations, SmallVector* strides, ArrayRef tempDilations, ArrayRef tempStrides) { if (!(dilations && strides)) return true; diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir new file mode 100644 index 0000000000000..8cd57044e613f --- /dev/null +++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir @@ -0,0 +1,615 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. +// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s + +func.func @conv_1d_nwc_wcf(%input: memref, %filter: memref, %output: memref) { + linalg.conv_1d_nwc_wcf {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_1d_nwc_wcf +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_1d_ncw_fcw(%input: memref, %filter: memref, %output: memref) { + linalg.conv_1d_ncw_fcw {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_1d_ncw_fcw +// CHECK: linalg.conv_1d_ncw_fcw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_1d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_1d ins(%in, %filter : memref, memref) + outs(%out : memref) + return +} +// CHECK: @conv_1d +// CHECK: linalg.conv_1d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_ncw_cw(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_ncw_cw +// CHECK: linalg.depthwise_conv_1d_ncw_cw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wc +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_1d_nwc_wcm(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wcm +// CHECK: linalg.depthwise_conv_1d_nwc_wcm +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_2d ins(%in, %filter : memref, memref) + outs(%out: memref) + return +} +// CHECK: @conv_2d +// CHECK: linalg.conv_2d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nchw_fchw(%arg0 : tensor, + %arg1 : tensor, %arg2 : tensor) -> + (tensor<4x8x12x16xf32>, tensor) { + %0 = linalg.conv_2d_nchw_fchw {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %1 = tensor.cast %0 : tensor to tensor<4x8x12x16xf32> + return %1, %0 : tensor<4x8x12x16xf32>, tensor +} +// CHECK: @conv_2d_nchw_fchw +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nchw_fchw_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw_q +// CHECK: linalg.conv_2d_nchw_fchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_fgchw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_fgchw +// CHECK: linalg.conv_2d_ngchw_fgchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_gfchw(%input: memref, %filter: memref, %output: memref) { + linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_2d_ngchw_gfchw +// CHECK: linalg.conv_2d_ngchw_gfchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_ngchw_gfchw_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: memref, memref, i32, i32) + outs (%output: memref) + return +} +// CHECK: @conv_2d_ngchw_gfchw_q +// CHECK: linalg.conv_2d_ngchw_gfchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_hwcf_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_nhwc_hwcf_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : memref, memref, i32, i32) outs(%output : memref) + return +} +// CHECK: @conv_2d_nhwc_hwcf_q +// CHECK: linalg.conv_2d_nhwc_hwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwgc_gfhwc_q(%input: memref, %filter: memref, %inputzp: i32, %filterzp: i32, %output: memref) { + linalg.conv_2d_nhwgc_gfhwc_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : memref, memref, i32, i32) outs(%output : memref) + return +} +// CHECK: @conv_2d_nhwgc_gfhwc_q +// CHECK: linalg.conv_2d_nhwgc_gfhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %output: tensor) -> tensor{ + %res = linalg.depthwise_conv_2d_nhwc_hwc_q { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter, %inputzp, %filterzp : tensor, tensor, i32, i32) outs(%output : tensor) -> tensor + return %res : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_fhwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc +// CHECK: linalg.conv_2d_nhwc_fhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_fhwc_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %inputzp, %filterzp: tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc_q +// CHECK: linalg.conv_2d_nhwc_fhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwc_hwcf(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_2d_nhwgc_gfhwc(%input: memref, %filter: memref, %output: memref) { + linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @conv_2d_nhwgc_gfhwc +// CHECK: linalg.conv_2d_nhwgc_gfhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : i32, %arg4 : i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d(%in : memref, %filter : memref, %out : memref) -> () { + linalg.conv_3d ins(%in, %filter : memref, memref) + outs(%out : memref) + return +} +// CHECK: @conv_3d +// CHECK: linalg.conv_3d +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ncdhw_fcdhw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ncdhw_fcdhw +// CHECK: linalg.conv_3d_ncdhw_fcdhw +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ndhwc_dhwcf(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf +// CHECK: linalg.conv_3d_ndhwc_dhwcf +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor, %filter: tensor, %inputzp: i32, %filterzp: i32, %init: tensor) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf_q {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins(%input, %filter, %inputzp, %filterzp : tensor, tensor, i32, i32) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf_q +// CHECK: linalg.conv_3d_ndhwc_dhwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ncdhw_cdhw +// CHECK: linalg.depthwise_conv_3d_ncdhw_cdhw +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwc +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwc +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nchw_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nchw_max +// CHECK: linalg.pooling_nchw_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nchw_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nchw_sum +// CHECK: linalg.pooling_nchw_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ncw_max(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_ncw_max {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ncw_max +// CHECK: linalg.pooling_ncw_max +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ncw_sum(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_ncw_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ncw_sum +// CHECK: linalg.pooling_ncw_sum +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ndhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_max +// CHECK: linalg.pooling_ndhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ndhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_min +// CHECK: linalg.pooling_ndhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_ndhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_ndhwc_sum +// CHECK: linalg.pooling_ndhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nwc_max(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nwc_min(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_min {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_min +// CHECK: linalg.pooling_nwc_min +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nwc_sum(%input: tensor, %output: tensor, %filter: tensor) -> tensor { + %0 = linalg.pooling_nwc_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter: tensor, tensor) + outs(%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nwc_sum +// CHECK: linalg.pooling_nwc_sum +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic From bba3921a36b934ee96f4bc4794e47dd329e9995b Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 14 Oct 2025 07:22:14 -0500 Subject: [PATCH 17/18] Format code --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 153 +- .../Dialect/Linalg/Transforms/Specialize.cpp | 221 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 2066 ++++++++++------- 3 files changed, 1527 insertions(+), 913 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 44ebc101d7c37..0f39098ca9946 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -115,50 +115,119 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// bool isaConv1DOp(LinalgOp op); -bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); bool isaConv2DOp(LinalgOp op); -bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNhwcFhwcQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNchwFchwQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNgchwFgchwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNgchwGfchwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNhwcHwcfQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNgchwGfchwQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv2DNhwgcGfhwcOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); bool isaConv3DOp(LinalgOp op); -bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNchwSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNcwSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); -bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector* dilations = nullptr, SmallVector* strides = nullptr); +bool isaConv3DNcdhwFcdhwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv3DNdhwcDhwcfOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNdhwcMaxOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNdhwcMinOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); +bool isaPoolingNdhwcSumOp(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 12eb17ef0a435..e08705b90e7b0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,9 +237,12 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } -/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` with `dilations` and `strides`. +/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. template -static FailureOr specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef dilations, ArrayRef strides) { +static FailureOr +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef dilations, ArrayRef strides) { SmallVector inputs = genericOp.getDpsInputs(); ValueRange outputs = genericOp.getDpsInits(); SmallVector indexingMaps = genericOp.getIndexingMapsArray(); @@ -247,159 +250,227 @@ static FailureOr specializeToConvOp(RewriterBase &rewriter, GenericOp ? TypeRange(ValueRange(outputs)) : TypeRange{}; LinalgOp namedOp; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs); + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, + inputs, outputs); } else { Attribute stridesAttr = rewriter.getI64TensorAttr(strides); Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); - namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + namedOp = rewriter.replaceOpWithNewOp( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); } return namedOp; } -static FailureOr inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; - if (isaConv1DOp(genericOp)) return specializeToConvOp(rewriter, genericOp, dilations, strides); + if (isaConv1DOp(genericOp)) + return specializeToConvOp(rewriter, genericOp, dilations, + strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaConv2DOp(genericOp)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, dilations, + strides); if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNcwSumOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNwcMinOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNwcSumOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaConv3DOp(genericOp)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, dilations, + strides); if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNchwSumOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); return failure(); } -static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr +inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp(rewriter, genericOp, + dilations, strides); if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides)) - return specializeToConvOp(rewriter, genericOp, dilations, strides); + return specializeToConvOp( + rewriter, genericOp, dilations, strides); return failure(); } -// Converts linalg.generic to named linalg.*conv/pooling* where possible. To improve the search speed, the convolution ops have been segregated based on the rank of iterator types array. -static FailureOr inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { - SmallVector iteratorTypes = genericOp.getIteratorTypesArray(); +// Converts linalg.generic to named linalg.*conv/pooling* where possible. To +// improve the search speed, the convolution ops have been segregated based on +// the rank of iterator types array. +static FailureOr +inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); unsigned totalIterators = iteratorTypes.size(); - switch(totalIterators) { - case 2: - return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); - case 4: - return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); - case 5: - return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); - case 6: - return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); - case 7: - return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); - case 8: - return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); - case 9: - return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); + switch (totalIterators) { + case 2: + return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); + case 4: + return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); + case 5: + return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); + case 6: + return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); + case 7: + return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); + case 8: + return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); + case 9: + return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 0d4e8aa5e6382..8ea5e7a10e17e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -249,28 +249,34 @@ template static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { Operation *defOp = yieldVal.getDefiningOp(); // if (!defOp) return false; - if (!(isa_and_present(defOp) || ...)) return false; + if (!(isa_and_present(defOp) || ...)) + return false; - BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); - BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); - if (!lhsArg || !rhsArg) return false; + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg) + return false; return true; } static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); + return bodyMatcherForPoolOps(yieldVal, + body); } static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); + return bodyMatcherForPoolOps(yieldVal, + body); } static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); + return bodyMatcherForPoolOps(yieldVal, + body); } static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { - return bodyMatcherForPoolOps(yieldVal, body); + return bodyMatcherForPoolOps(yieldVal, + body); } static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { @@ -288,7 +294,8 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, // Check if `expr` is either: // - a dimension expr alone (implying *1), or // - a multiplication of dimension expr by constant. -static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) { +static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, + int64_t &constantValue) { if (auto dExpr = dyn_cast(expr)) { dim = dExpr; constantValue = 1; @@ -320,12 +327,13 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_ } /// Given an array of AffineMaps `indexingMaps` verify the following :- -/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[0].getResult(iDim) == /// indexingMaps[1].getResult(fDim) * + /// indexingMaps[n-1].getResult(oDim) * /// where, CST_1 and CST_2 can be any constant. -static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, - int64_t& dilation, int64_t& stride) { +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); auto addExpr = dyn_cast(inpExpr); @@ -354,24 +362,37 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un } /// Given an array of AffineMaps `indexingMaps` verify the following :- -/// indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim) -static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) { - return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim); -} - -/// Give an array of AffineMaps, verify each map to be of the corresponding `expectedSize`. -static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef expectedSizes) { - if (indexingMaps.size() != expectedSizes.size()) return false; +/// indexingMaps[aIndex].getResult(aDim) == +/// indexingMaps[bIndex].getResult(bDim) +static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, + unsigned aDim, unsigned bIndex, + unsigned bDim) { + return getAffineMapDim(indexingMaps, aIndex, aDim) == + getAffineMapDim(indexingMaps, bIndex, bDim); +} + +/// Give an array of AffineMaps, verify each map to be of the corresponding +/// `expectedSize`. +static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, + ArrayRef expectedSizes) { + if (indexingMaps.size() != expectedSizes.size()) + return false; - for (auto [indexingMap, expectedSize] : llvm::zip_equal(indexingMaps, expectedSizes)) { + for (auto [indexingMap, expectedSize] : + llvm::zip_equal(indexingMaps, expectedSizes)) { auto affineMap = cast(indexingMap).getValue(); - if (affineMap.getNumResults() != expectedSize) return false; + if (affineMap.getNumResults() != expectedSize) + return false; } return true; } -/// Utility to update `dilations` and `strides` by copy the corresponding data from `tempDilations` and `tempStrides`. -static bool updateConvDilationsAndStrides(SmallVector* dilations, SmallVector* strides, ArrayRef tempDilations, ArrayRef tempStrides) { +/// Utility to update `dilations` and `strides` by copy the corresponding data +/// from `tempDilations` and `tempStrides`. +static bool updateConvDilationsAndStrides(SmallVector *dilations, + SmallVector *strides, + ArrayRef tempDilations, + ArrayRef tempStrides) { if (!(dilations && strides)) return true; for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { @@ -382,1087 +403,1540 @@ static bool updateConvDilationsAndStrides(SmallVector* dilations, Small } bool isaConv1DOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {1,1,1})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {1, 1, 1})) + return false; + // #map = affine_map<(d0, d1) -> (d0 + d1)> // #map1 = affine_map<(d0, d1) -> (d1)> // #map2 = affine_map<(d0, d1) -> (d0)> - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); - return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]); + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]); } -bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; +bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); } // ------------------- -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; +bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)> // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); } bool isaConv2DOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + if (!verifyConvIndexingMapSizes(indexingMaps, {2, 2, 2})) + return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> - return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1])); + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[1], + tempStrides[1])); } -bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; +bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} -bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} -bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} -bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, + // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4) - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} -bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, + // d4, d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4) + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, + // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, + // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> + // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, + // d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, + // d4, d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, + // d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, + // d3)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 0, 0, 4})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); } bool isaConv3DOp(LinalgOp op) { - if (isa(op)) return true; + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3})) + return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)> - return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2])); -} - -bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[2], + tempStrides[2])); +} + +bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, + // d3 + d7, d4 + d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d1, d5, d6, d7, d8)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, + /*oDim=*/4, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 4; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()> - // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> ()> #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> + // (d0, d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + // + d6, d3 + d7, d8)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, + // d7, d8) -> (d0, d1, d2, d3, d8, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + // + d5, d3 + d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, + // d4, d5, d6)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d7, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, + /*oDim=*/4, tempDilations[2], + tempStrides[2])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5})) + return false; + unsigned iIndex = 0, fIndex = 1, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + + // d5, d3 + d6, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d4, d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d0, d1, d2, d3, d7)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNchwSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) && - bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - bodyMatcherForMinSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(2,1); - SmallVector tempStrides(2,1); + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && - bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNcwSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) && - bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - bodyMatcherForMinSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - SmallVector tempDilations(1,1); - SmallVector tempStrides(1,1); + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)> // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && - bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; - if (!isaConvolutionOpInterface(op)) return false; + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - bodyMatcherForMinSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); -} - -bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector* dilations, SmallVector* strides) { - if (isa(op)) return true; - - if (!isaConvolutionOpInterface(op)) return false; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; ArrayAttr indexingMaps = op.getIndexingMaps(); - if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false; - + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5})) + return false; + Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); unsigned iIndex = 0, oIndex = 2; - - SmallVector tempDilations(3,1); - SmallVector tempStrides(3,1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> - bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + + // d6, d3 + d7, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> + // (d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, + // d1, d2, d3, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); } Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, From 3852dc4ffeac76056c4e31ac74397ade5f3dc228 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 15 Oct 2025 03:26:58 -0500 Subject: [PATCH 18/18] Export just a single API --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 116 +---- .../Dialect/Linalg/Transforms/Specialize.cpp | 132 +++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 452 ++++++++++++++---- 3 files changed, 450 insertions(+), 250 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 0f39098ca9946..771d753a8bddb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -111,123 +111,13 @@ std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// -// Convolution matcher utilities +// Convolution matcher utility //===----------------------------------------------------------------------===// -bool isaConv1DOp(LinalgOp op); -bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DOp(LinalgOp op); -bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNhwcFhwcQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNchwFchwQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNgchwFgchwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNgchwGfchwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNhwcHwcfQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, +template +bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations = nullptr, SmallVector *strides = nullptr); -bool isaConv2DNgchwGfchwQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv2DNhwgcGfhwcOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv3DOp(LinalgOp op); -bool isaConv3DNcdhwFcdhwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv3DNdhwcDhwcfOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNdhwcMaxOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNdhwcMinOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); -bool isaPoolingNdhwcSumOp(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index e08705b90e7b0..929904fa2c510 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -268,7 +268,7 @@ static FailureOr inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaConv1DOp(genericOp)) + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); return failure(); @@ -278,28 +278,35 @@ static FailureOr inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaConv2DOp(genericOp)) + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNcwSumOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNwcMinOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNwcSumOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); return failure(); @@ -309,13 +316,16 @@ static FailureOr inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); return failure(); @@ -325,37 +335,47 @@ static FailureOr inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaConv3DOp(genericOp)) + if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNchwSumOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); return failure(); @@ -365,28 +385,36 @@ static FailureOr inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); return failure(); @@ -396,34 +424,44 @@ static FailureOr inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); return failure(); @@ -433,16 +471,20 @@ static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; - if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) return specializeToConvOp(rewriter, genericOp, dilations, strides); - if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides)) + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 8ea5e7a10e17e..13235d99887a7 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,15 +240,14 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } -// ------------------------------- -// ---------- CONV --------------- -// ------------------------------- +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// /// Utility to match block body for linalg.pool* ops. template static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { Operation *defOp = yieldVal.getDefiningOp(); - // if (!defOp) return false; if (!(isa_and_present(defOp) || ...)) return false; @@ -402,7 +401,7 @@ static bool updateConvDilationsAndStrides(SmallVector *dilations, return true; } -bool isaConv1DOp(LinalgOp op) { +static bool isaConv1DOp(LinalgOp op) { if (isa(op)) return true; @@ -423,8 +422,8 @@ bool isaConv1DOp(LinalgOp op) { tempStrides[0]); } -bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -453,8 +452,8 @@ bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -483,8 +482,9 @@ bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -514,8 +514,9 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector *dilations, } // ------------------- -bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -544,8 +545,9 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -575,7 +577,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DOp(LinalgOp op) { +static bool isaConv2DOp(LinalgOp op) { if (isa(op)) return true; @@ -599,8 +601,8 @@ bool isaConv2DOp(LinalgOp op) { tempStrides[1])); } -bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -632,8 +634,8 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -665,8 +667,8 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -698,8 +700,8 @@ bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -732,8 +734,8 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -766,8 +768,8 @@ bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -802,8 +804,8 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -838,8 +840,8 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -872,8 +874,8 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -908,8 +910,8 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -945,8 +947,8 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -981,8 +983,9 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1014,8 +1017,9 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1047,8 +1051,9 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1081,8 +1086,9 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1116,8 +1122,9 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1150,7 +1157,7 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv3DOp(LinalgOp op) { +static bool isaConv3DOp(LinalgOp op) { if (isa(op)) return true; @@ -1177,8 +1184,8 @@ bool isaConv3DOp(LinalgOp op) { tempStrides[2])); } -bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1214,8 +1221,8 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1251,8 +1258,8 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1289,9 +1296,9 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1328,8 +1335,9 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1365,8 +1373,9 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1402,8 +1411,8 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1438,8 +1447,8 @@ bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1474,8 +1483,8 @@ bool isaPoolingNchwSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1510,8 +1519,8 @@ bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1546,8 +1555,8 @@ bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1582,8 +1591,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1618,8 +1628,9 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1654,8 +1665,8 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1687,8 +1698,8 @@ bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1720,8 +1731,8 @@ bool isaPoolingNcwSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1753,8 +1764,8 @@ bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1786,8 +1797,8 @@ bool isaPoolingNwcMinOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1819,8 +1830,8 @@ bool isaPoolingNwcSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1859,8 +1870,8 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1899,8 +1910,8 @@ bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +static bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -1939,6 +1950,263 @@ bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } +template +bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if constexpr (std::is_same_v) { + return isaConv1DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv1DNwcWcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv1DNcwFcwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNcwCwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv1DNwcWcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcFhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcHwcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNchwFchwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcFhwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNchwFchwQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNgchwFgchwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNgchwGfchwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwcHwcfQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwgcGfhwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNgchwGfchwQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv2DNhwgcGfhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv2DNhwcHwcmQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DOp(op); + } else if constexpr (std::is_same_v) { + return isaConv3DNcdhwFcdhwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DNdhwcDhwcfOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaConv3DNdhwcDhwcfQOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNcdhwCdhwOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaDepthwiseConv3DNdhwcDhwcOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNchwMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNchwSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNcwMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNcwSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNwcSumOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcMaxOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcMinOp(op, dilations, strides); + } else if constexpr (std::is_same_v) { + return isaPoolingNdhwcSumOp(op, dilations, strides); + } else { + return false; + } +} + +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool +isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); +template bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides); + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) {