diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 23d35b3dc2cc..2aa6a0383e5b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7744,6 +7744,7 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
@@ -7857,6 +7858,7 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
@@ -8001,6 +8003,7 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
       printDefaultTorchOp(printer, *this, 7, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [
@@ -8060,6 +8063,7 @@ def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [
       printDefaultTorchOp(printer, *this, 7, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 78c489387aa2..227ed82d9ee9 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6046,12 +6046,6 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
   if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool1dOp>())
     params.push_back(val);
-
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    if (params.size() == 1)
-      params.push_back(params[0]);
-  }
 }
 
 // Checks the validity of pooling parameters and stores them in the respective
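
A quick aside on the PyTorch semantics both changes rely on: for the pooling modules, a single-element tuple is broadcast to every spatial dimension, so the TOSA-specific singleton expansion removed above becomes redundant once the new canonicalizer normalizes the lists at the Torch level. A minimal standalone check of that assumption (illustrative Python, not part of the patch):

    import torch

    x = torch.rand(2, 4, 20, 20)
    single = torch.nn.MaxPool2d(kernel_size=(6,), stride=(2, 2), padding=(1, 1))
    full = torch.nn.MaxPool2d(kernel_size=(6, 6), stride=(2, 2), padding=(1, 1))
    # kernel_size=(6,) broadcasts to (6, 6), so both modules select the same
    # elements and the outputs match exactly.
    assert torch.equal(single(x), full(x))
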
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e50be5ff97ae..555073f604ca 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5601,6 +5601,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
   });
 }
 
+namespace {
+
+// Broadcast a single-element parameter list to all spatial dimensions.
+void expand(SmallVectorImpl<int64_t> &params, int numSpatialDims) {
+  if (params.size() == 1) {
+    for ([[maybe_unused]] int dim : llvm::seq(0, numSpatialDims - 1)) {
+      params.push_back(params[0]);
+    }
+  }
+}
+
+// Extract the constant kernel/stride/padding (and dilation, for max pooling)
+// lists of a pooling op, broadcast any singleton to the full rank, and
+// rebuild them as torch list values.
+template <typename AtenPoolOpT>
+LogicalResult expandPoolParams(AtenPoolOpT op, int numSpatialDims,
+                               mlir::PatternRewriter &rewriter,
+                               Value &kernelSizeList, Value &stridesList,
+                               Value &paddingList, Value &dilationsList) {
+
+  SmallVector<int64_t> kernelSizeInts, strideInts, paddingInts, dilationInts;
+  if (!matchPattern(op.getKernelSize(),
+                    m_TorchListOfConstantInts(kernelSizeInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const kernel_size for pooling op unsupported");
+
+  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const padding factor for pooling op unsupported");
+
+  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const stride for pooling op unsupported");
+
+  if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
+                std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const dilation for pooling op unsupported");
+
+    if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
+        strideInts.size() != 1 && dilationInts.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Expected one of kernel/stride/padding/dilation to be singleton.");
+    }
+
+    expand(dilationInts, numSpatialDims);
+
+  } else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
+             strideInts.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        op, "Expected one of kernel/stride/padding to be singleton.");
+  }
+
+  // Expand singleton elements to cover all spatial dimensions.
+  expand(kernelSizeInts, numSpatialDims);
+  expand(paddingInts, numSpatialDims);
+  expand(strideInts, numSpatialDims);
+
+  Location loc = op.getLoc();
+
+  SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
+  for (auto dim : llvm::seq<size_t>(0, kernelSizeInts.size())) {
+    cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(kernelSizeInts[dim])));
+    cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
+    cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(strideInts[dim])));
+  }
+
+  // Set dilations separately, since they are never set for the AvgPool ops.
+  for (auto dim : llvm::seq<size_t>(0, dilationInts.size())) {
+    cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
+  }
+
+  auto targetListType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
+      loc, targetListType, cstKernel);
+  paddingList = rewriter.create<Torch::PrimListConstructOp>(
+      loc, targetListType, cstPadding);
+  stridesList = rewriter.create<Torch::PrimListConstructOp>(
+      loc, targetListType, cstStrides);
+  dilationsList = rewriter.create<Torch::PrimListConstructOp>(
+      loc, targetListType, cstDilations);
+
+  return success();
+}
+
+template <typename AvgPoolOpT>
+struct CanonicalizeAvgPoolWithSingleIntTuple
+    : public mlir::OpRewritePattern<AvgPoolOpT> {
+  CanonicalizeAvgPoolWithSingleIntTuple(mlir::MLIRContext *context)
+      : OpRewritePattern<AvgPoolOpT>(context, /*benefit=*/1) {}
+
+  LogicalResult
+  matchAndRewrite(AvgPoolOpT op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Value kernel, stride, pad, dilations;
+
+    auto numSpatialDims = 2;
+    if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
+      numSpatialDims = 3;
+
+    // Attempt to expand params if necessary.
+    if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
+                                pad, dilations)))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to expand params for AvgPooling");
+
+    rewriter.replaceOpWithNewOp<AvgPoolOpT>(
+        op, op.getResult().getType(), op.getSelf(), kernel, stride, pad,
+        op.getCeilMode(), op.getCountIncludePad(), op.getDivisorOverride());
+    return success();
+  }
+};
+
+template <typename MaxPoolOpT>
+struct CanonicalizeMaxPoolWithSingleIntTuple
+    : public mlir::OpRewritePattern<MaxPoolOpT> {
+  CanonicalizeMaxPoolWithSingleIntTuple(mlir::MLIRContext *context)
+      : OpRewritePattern<MaxPoolOpT>(context, /*benefit=*/1) {}
+
+  LogicalResult
+  matchAndRewrite(MaxPoolOpT op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Value kernel, stride, pad, dilations;
+
+    auto numSpatialDims = 2;
+    if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
+      numSpatialDims = 3;
+
+    // Attempt to expand params if necessary.
+    if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
+                                pad, dilations)))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to expand params for MaxPooling");
+
+    rewriter.replaceOpWithNewOp<MaxPoolOpT>(op, op.getResult().getType(),
+                                            op.getSelf(), kernel, stride, pad,
+                                            dilations, op.getCeilMode());
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtenAvgPool2dOp
+//===----------------------------------------------------------------------===//
+void AtenAvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenAvgPool3dOp
+//===----------------------------------------------------------------------===//
+void AtenAvgPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxPool2dOp
+//===----------------------------------------------------------------------===//
+void AtenMaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxPool3dOp
+//===----------------------------------------------------------------------===//
+void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLinalgCrossOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index dee936116a73..67cacfe69f9c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -539,8 +539,6 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
-    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
-    "AvgPool2dSingleIntTupleParamsModule_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
     "RollModule_basic",
 }
@@ -985,8 +983,6 @@
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
-    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
-    "AvgPool2dSingleIntTupleParamsModule_basic",
     "BatchNorm1DModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
@@ -2841,6 +2837,7 @@
     "AvgPool1dPadCeilPadNotIncluded_basic",
     "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
+    "AvgPool3dSingleIntTupleParamsModule_basic",
     "BatchMlpLayerModule_basic",
     "BincountMinlengthModule_basic",
     "BincountModule_basic",
@@ -3028,6 +3025,7 @@
     "MaxPool2dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool2dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool2dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool2dSingleIntTupleParamsModule_basic",
     "MaxPool3dCeilModeTrueModule_basic",
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
@@ -3039,6 +3037,7 @@
     "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool3dSingleIntTupleParamsModule_basic",
     "MaxUnpool3dModule_basic",
     "MaxUnpool3dModulePad0_basic",
     "MeanDimEmptyDimModule_basic",
"MeanDimEmptyDimModule_basic", @@ -3529,6 +3528,7 @@ "AvgPool3dStaticModule_basic", "AvgPool3dCountIncludePadFalse_basic", "AvgPool3dCountIncludePadFalseWithoutPadding_basic", + "AvgPool3dSingleIntTupleParamsModule_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", "Conv_Transpose2dStaticModule_basic", @@ -3782,6 +3782,7 @@ "MaxPool3dWithIndicesNonDefaultParamsModule_basic", "MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesStaticModule_basic", + "MaxPool3dSingleIntTupleParamsModule_basic", "MeanDimEmptyDimModule_basic", "MlGroupNormManualModule_basic", "MlGroupNormModule_basic", @@ -4205,6 +4206,7 @@ "AvgPool2dIntModule_basic", "AvgPool2dStaticModule_basic", "AvgPool2dWithoutPadModule_basic", + "AvgPool3dSingleIntTupleParamsModule_basic", "BatchMlpLayerModule_basic", "BernoulliFloatModule_basic", "BernoulliModule_basic", @@ -4612,6 +4614,7 @@ "MaxPool3dWithIndicesNonDefaultParamsModule_basic", "MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesStaticModule_basic", + "MaxPool3dSingleIntTupleParamsModule_basic", "MeanDimAllReduceKeepdimModule_basic", "MeanDimAllReduceModule_basic", "MeanDimDtypeModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 891f1e5f8bcb..84685cc10b9a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -657,7 +657,10 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) - emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit( + "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)", + has_canonicalizer=True, + ) emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", @@ -666,7 +669,10 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) - emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit( + "aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)", + has_canonicalizer=True, + ) emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", @@ -677,13 +683,15 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)") emit( - "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)", + has_canonicalizer=True, ) emit( "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) emit( - "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)", + has_canonicalizer=True, ) emit( "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)" diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 9ef3cffb2193..5ab5cba60186 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -550,6 +550,31 @@ def MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic(module, tu: TestUtil module.forward(tu.rand(1, 1, 75, 75, low=-1)) +class MaxPool2dSingleIntTupleParamsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mpd = torch.nn.MaxPool2d( + kernel_size=(6,), + stride=(2, 2), + padding=(1, 1), + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.mpd(x) + + +@register_test_case(module_factory=lambda: MaxPool2dSingleIntTupleParamsModule()) +def MaxPool2dSingleIntTupleParamsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== @@ -722,6 +747,32 @@ def MaxPool3dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, 20, low=0.5, high=1.0)) +class MaxPool3dSingleIntTupleParamsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mpd = torch.nn.MaxPool3d( + kernel_size=(6, 6, 6), + stride=(2, 2, 2), + padding=(1, 1, 1), + dilation=(2,), + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.mpd(x) + + +@register_test_case(module_factory=lambda: MaxPool3dSingleIntTupleParamsModule()) +def MaxPool3dSingleIntTupleParamsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== @@ -1810,6 +1861,31 @@ def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 12, 12, 12, low=-1)) +class AvgPool3dSingleIntTupleParamsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.apd = torch.nn.AvgPool3d( + kernel_size=(6, 6, 6), + stride=(2,), + padding=(1, 1, 1), + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.apd(x) + + +@register_test_case(module_factory=lambda: AvgPool3dSingleIntTupleParamsModule()) +def AvgPool3dSingleIntTupleParamsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a025ec09726d..7607872a46a3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3389,3 +3389,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1 torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> return %3 : !torch.vtensor<[?],f32> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.avg_pool2d.single_int_tuple( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[C_6:.*]] = torch.constant.int 6 +// CHECK: %[[C_1:.*]] = torch.constant.int 1 +// CHECK: 
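
The new e2e modules each exercise one singleton parameter (kernel_size for MaxPool2d, dilation for MaxPool3d, stride for AvgPool3d) against eager PyTorch as the reference. The premise they encode can be checked directly in eager mode; a small sketch of the AvgPool3d case (illustrative only):

    import torch

    x = torch.rand(2, 4, 20, 20, 20)
    single = torch.nn.AvgPool3d(kernel_size=(6, 6, 6), stride=(2,), padding=(1, 1, 1))
    full = torch.nn.AvgPool3d(kernel_size=(6, 6, 6), stride=(2, 2, 2), padding=(1, 1, 1))
    # stride=(2,) broadcasts to (2, 2, 2); both modules average the same
    # windows, so the results match.
    assert torch.equal(single(x), full(x))
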
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a025ec09726d..7607872a46a3 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3389,3 +3389,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
   torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %3 : !torch.vtensor<[?],f32>
 }
+
+// -----
+// CHECK-LABEL:   func.func @torch.aten.avg_pool2d.single_int_tuple(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[C_6:.*]] = torch.constant.int 6
+// CHECK:           %[[C_1:.*]] = torch.constant.int 1
+// CHECK:           %[[C_2:.*]] = torch.constant.int 2
+// CHECK:           %[[KERNEL:.*]] = torch.prim.ListConstruct %[[C_6]], %[[C_6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[PAD:.*]] = torch.prim.ListConstruct %[[C_1]], %[[C_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C_2]], %[[C_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[POOL:.*]] = torch.aten.avg_pool2d %[[ARG0]], %[[KERNEL]], %[[STRIDE]], %[[PAD]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
+// CHECK:           return %[[POOL]]
+func.func @torch.aten.avg_pool2d.single_int_tuple(%arg0: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prim.ListConstruct %int6 : (!torch.int) -> !torch.list<int>
+  %int2 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
+  %int1 = torch.constant.int 1
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %false = torch.constant.bool false
+  %false_0 = torch.constant.bool false
+  %none = torch.constant.none
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false_0, %none : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
+  return %3 : !torch.vtensor<[2,4,9,9],f32>
+}
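
Besides running llvm-lit on the test above, the canonicalizer can be exercised from Python through the FX importer. The sketch below assumes the current torch-mlir binding names (torch_mlir.fx.export_and_import and torch_mlir.passmanager.PassManager); treat it as a starting point and verify the API against your checkout:

    import torch
    from torch_mlir import fx
    from torch_mlir.passmanager import PassManager

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Same singleton kernel_size as the new e2e test.
            self.pool = torch.nn.MaxPool2d(kernel_size=(6,), stride=(2, 2), padding=(1, 1))

        def forward(self, x):
            return self.pool(x)

    module = fx.export_and_import(M(), torch.rand(2, 4, 20, 20))
    # The canonicalize pass should rewrite the single-element kernel list
    # into a full [6, 6] list on the aten.max_pool2d op.
    PassManager.parse("builtin.module(canonicalize)", context=module.context).run(module.operation)
    print(module)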