diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c336660335a6..82d11ec0737a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7119,6 +7119,7 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ printDefaultTorchOp(printer, *this, 9, 1); } }]; + let hasCanonicalizer = 1; } def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 0df423f41cbf..108882881df6 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -6,6 +6,7 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallVector.h" #define DEBUG_TYPE "torch-mlir-torch-dialect" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -5898,6 +5899,160 @@ void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add>(context); } +namespace { +class CanonicalizeConvolutionWithSingleIntTuple + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtenConvolutionOp op, + PatternRewriter &rewriter) const override { + + auto weight = op.getWeight(); + auto weightType = dyn_cast(weight.getType()); + + if (!weightType) { + return rewriter.notifyMatchFailure(op, "weight is not a vtensor"); + } + auto optionalSizes = weightType.getOptionalSizes(); + if (!optionalSizes.has_value()) { + return rewriter.notifyMatchFailure(op, + "unranked weight tensor unsupported!"); + } + + // The rank is the size of the dimensions array + int64_t weightRank = optionalSizes.value().size(); + + // We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv). + if (weightRank < 4 || weightRank > 5) { + return rewriter.notifyMatchFailure( + op, "unsupported weight rank (must be 4 or 5)"); + } + int requiredSpatialDims = weightRank - 2; + + // Validate stride, padding, output_padding, and dilation are constant + // lists. + SmallVector strideInts; + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) { + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); + } + SmallVector paddingInts; + if (!matchPattern(op.getPadding(), + m_TorchListOfConstantInts(paddingInts))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + + SmallVector dilationInts; + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) { + return rewriter.notifyMatchFailure( + op, "non-const int tranposed unsupported!"); + } + + SmallVector outputPaddingInts; + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInts))) { + return rewriter.notifyMatchFailure( + op, "non-const int output_padding unsupported!"); + } + + // Canonicalization Logic: Only rewrite if padding provided is 1 element + // but the convolution requires 2 or 3 elements. + auto isCanonical = [requiredSpatialDims](ArrayRef param) { + return param.size() == static_cast(requiredSpatialDims); + }; + + if (isCanonical(strideInts) && isCanonical(paddingInts) && + isCanonical(dilationInts)) { + return rewriter.notifyMatchFailure( + op, "stride, padding, dialtion and outputPadding is already fully " + "specified"); + } + + if (transposed && isCanonical(outputPaddingInts)) { + return rewriter.notifyMatchFailure( + op, "output_padding is already fully specified"); + } + + expand(strideInts, requiredSpatialDims); + expand(paddingInts, requiredSpatialDims); + expand(dilationInts, requiredSpatialDims); + + if (transposed) + expand(outputPaddingInts, requiredSpatialDims); + + // Construct the new List + // For example: If user provided padding=[1], and we need 2 or 3 dims, we + // create padding=[1, 1] or padding = [1,1,1] + Location loc = op.getLoc(); + SmallVector cstPadding, cstStrides, cstDilation, cstOutputPadding; + + for (auto dim : llvm::seq(0, requiredSpatialDims)) { + + cstStrides.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(strideInts[dim]))); + + cstPadding.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(paddingInts[dim]))); + + cstDilation.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim]))); + + if (transposed) + cstOutputPadding.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim]))); + } + + auto targetListType = + Torch::ListType::get(Torch::IntType::get(op->getContext())); + + // Create the list construct op + auto stridesList = Torch::PrimListConstructOp::create( + rewriter, loc, targetListType, cstStrides); + auto paddingList = Torch::PrimListConstructOp::create( + rewriter, loc, targetListType, cstPadding); + auto dilationsList = Torch::PrimListConstructOp::create( + rewriter, loc, targetListType, cstDilation); + + Value outputPaddingList; + if (transposed) { + outputPaddingList = Torch::PrimListConstructOp::create( + rewriter, loc, targetListType, cstOutputPadding); + } else { + outputPaddingList = op.getOutputPadding(); + } + + // Replace the Op + // We create a new convolution op, keeping all operands the same except + // stride, padding,dilation, and output_padding which are now fully + // specified + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInput(), op.getWeight(), op.getBias(), + stridesList.getResult(), paddingList.getResult(), + dilationsList.getResult(), op.getTransposed(), outputPaddingList, + op.getGroups()); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// AtenConvolutionOp Registration +//===----------------------------------------------------------------------===// +void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // AtenLinalgCrossOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6911821a8882..73e858816fed 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1133,8 +1133,10 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Convolution2DStaticModule_basic", + "Convolution2DSingleIntTupleModule_basic", "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeScalarTupleParams_basic", "Conv_Transpose1dStaticModule_basic", "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dStaticModule_basic", @@ -2168,6 +2170,7 @@ "Conv2dWithValidPaddingModule_basic", "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", + "Convolution2DSingleIntTupleModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", "DropoutEvalFloatModule_basic", @@ -2908,6 +2911,7 @@ "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", + "Conv3dModuleScalarTupleParams_basic", "Conv3dWithSamePaddingModule_basic", "Conv3dWithValidPaddingModule_basic", "ConvolutionModule3DGroups_basic", @@ -2923,7 +2927,9 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", + "Convolution2DSingleIntTupleModule_basic", "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeScalarTupleParams_basic", "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTranspose_basic", # Error: onnx lowering, @@ -3694,6 +3700,7 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv3dModule_basic", + "Conv3dModuleScalarTupleParams_basic", "Conv3dWithSamePaddingModule_basic", "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", @@ -4333,6 +4340,7 @@ "Conv2dWithSamePaddingModule_basic", "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", + "Conv3dModuleScalarTupleParams_basic", "Conv3dWithSamePaddingModule_basic", "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", @@ -4340,6 +4348,7 @@ "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", + "Convolution2DSingleIntTupleModule_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStatic_basic", "ConvolutionBackwardModule2DStrided_basic", @@ -4347,6 +4356,7 @@ "ConvolutionModule2DGroups_basic", "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeScalarTupleParams_basic", "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTranspose_basic", "ConvolutionModule2DGroupedTranspose_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index df75d9427480..7d5e65c21cef 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -612,7 +612,8 @@ def emit_with_mutating_variants(key, **kwargs): "aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)" ) emit( - "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)" + "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)", + has_canonicalizer=True, ) emit( "aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)" diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 25c6b03f5424..974bf0d56962 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -304,6 +304,37 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) +class Convolution2DSingleIntTupleModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 3, 10, 10], torch.float32, True), + ([3, 3, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=(1,), + padding=(0,), + dilation=(1,), + transposed=False, + output_padding=[0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Convolution2DSingleIntTupleModule()) +def Convolution2DSingleIntTupleModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + + class Convolution2DStridedModule(torch.nn.Module): def __init__(self): super().__init__() @@ -901,6 +932,39 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3)) +class ConvolutionModule2DTransposeScalarTupleParams(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution( + inputVec, + weight, + bias=None, + stride=(1,), + padding=(1,), + dilation=(1,), + transposed=True, + output_padding=(0,), + groups=1, + ) + + +@register_test_case( + module_factory=lambda: ConvolutionModule2DTransposeScalarTupleParams() +) +def ConvolutionModule2DTransposeScalarTupleParams_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + + class Conv_Transpose1dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1569,6 +1633,39 @@ def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv3dModuleScalarTupleParams(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=(1,), + padding=(0,), + dilation=(1,), + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv3dModuleScalarTupleParams()) +def Conv3dModuleScalarTupleParams_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class ConvTbcModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index b7156b364928..dcf24ac6842d 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3623,3 +3623,31 @@ func.func @torch.aten.avg_pool2d.single_int_tuple(%arg0: !torch.vtensor<[2,4,20, %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false_0, %none : !torch.vtensor<[2,4,20,20],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32> return %3 : !torch.vtensor<[2,4,9,9],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution.single_int_tuple( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> { +// CHECK: %[[ISTRANSPOSE:.*]] = torch.constant.bool true +// CHECK: %[[WEIGHTS:.*]] = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32> +// CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[OUT:.*]] = torch.aten.convolution %[[ARG0]], %[[WEIGHTS]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[ISTRANSPOSE]], %[[OUTPUT_PADDING]], %[[INT1]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> +func.func @torch.aten.convolution.single_int_tuple(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> { + %w = torch.vtensor.literal(dense<-7.486820e-03> : tensor<1x1x1x1xf32>) : !torch.vtensor<[1,1,1,1],f32> + %b = torch.vtensor.literal(dense<0.536443591> : tensor<1xf32>) : !torch.vtensor<[1],f32> + %int1 = torch.constant.int 1 + %stride = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %int0 = torch.constant.int 0 + %padding = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %true = torch.constant.bool true + %output_padding = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %6 = torch.aten.convolution %arg0, %w, %b, %stride, %padding, %dilation, %true, %output_padding, %int1 : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + return %6 : !torch.vtensor<[1,1,5,5],f32> +}