From aca33f1742096e7e6cb3152be15140cf9f71e508 Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Tue, 22 Oct 2024 20:26:16 +0200 Subject: [PATCH 01/19] [TorchToLinalg] Use Op with native channel order for quantized conv2d (#3807) I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (https://github.com/llvm/llvm-project/pull/107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore removes the requirement to try to optimize these away in downstream passes. --- lib/Conversion/TorchToLinalg/Linear.cpp | 59 ++++++++++--------- .../Conversion/TorchToLinalg/convolution.mlir | 8 +-- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index a4962d12abdc..9c914690bbf4 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } if (numGroups == 1 && inputZp) { - // The quantized version uses a different channel ordering so we need to - // permute the tensors in order to use the existing path. We should - // eventually directly support this channel ordering. - llvm::SmallVector inPerms, weightPerms; - inPerms.push_back(0); // N stays at the front for input. - // Then we expect the spatial dimensions - for (size_t i = 0; i < numSpatialDims; ++i) { - inPerms.push_back(i + 2); - weightPerms.push_back(i + 2); - } - inPerms.push_back(1); - weightPerms.append({1, 0}); - - paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); - weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); - outputTensor = - transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - switch (numSpatialDims) { case 2: conv = rewriter - .create( + .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; - case 3: + case 3: { + // The quantized version uses a different channel ordering so we need to + // permute the tensors in order to use the existing path. We should + // eventually directly support this channel ordering. + llvm::SmallVector inPerms, weightPerms; + inPerms.push_back(0); // N stays at the front for input. 
+ // Then we expect the spatial dimensions + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 2); + } + inPerms.push_back(1); + weightPerms.append({1, 0}); + + paddedInput = + transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); + conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); + + llvm::SmallVector outPerms; + outPerms.push_back(0); + outPerms.push_back(inPerms.size() - 1); + for (size_t i = 0; i < numSpatialDims; ++i) { + outPerms.push_back(i + 1); + } + conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); + break; + } default: return rewriter.notifyMatchFailure( op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; - llvm::SmallVector outPerms; - outPerms.push_back(0); - outPerms.push_back(inPerms.size() - 1); - for (size_t i = 0; i < numSpatialDims; ++i) { - outPerms.push_back(i + 1); - } - conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); - Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir index 3023c0ba6d8a..480b1eeb9ed2 100644 --- a/test/Conversion/TorchToLinalg/convolution.mlir +++ b/test/Conversion/TorchToLinalg/convolution.mlir @@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128] // CHECK: %[[c7:.*]] = arith.constant 7 : i32 // CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor // CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor -// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor) -// CHECK-SAME: permutation = [0, 2, 3, 1] -// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor) -// CHECK-SAME: permutation = [2, 3, 1, 0] -// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} -// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor, tensor, i32, i32) +// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} +// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor, tensor, i32, i32) // CHECK-SAME: outs(%[[convout:.*]] : tensor) -> tensor func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %false = torch.constant.bool false From 55ff110dc29cab7e2495ccdbec9a60512c29c665 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 23 Oct 2024 03:08:55 -0500 Subject: [PATCH 02/19] [MLIR][TORCH] Only unroll prim loop-like ops within a `torch.shape.calculate` region (#3812) Reports a match failure for the pattern `FullyUnrollPrimLoop` when the loop op is not in a region defined by a `torch.shape.calculate` op. This is needed to avoid unrolling prim loops generated by ONNX IR, since we are applying shape refinement in the `torch-onnx-to-torch-backend-pipeline` introduced in fa4794d . 
See also the discussion in --- .../SimplifyAbstractInterpCalculationsUtils.cpp | 9 ++++++--- .../Torch/simplify-shape-calculations.mlir | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index f1ebeb307976..d599fd5369f4 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -32,9 +32,6 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern { } // namespace namespace { -// TODO: Only unroll inside the shape calculation region. -// Maybe do this by only applying patterns and folding greedily on the ops -// inside the region + the shape.calculate op itself? class FullyUnrollPrimLoopOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -42,6 +39,12 @@ class FullyUnrollPrimLoopOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); + // Only unroll loops if they are contained in a shape calculate region. + Region *region = op->getParentRegion(); + Operation *parentOp = region->getParentOp(); + if (!parentOp || !isa(parentOp)) + return rewriter.notifyMatchFailure( + op, "Loop is not contained in a shape calculation region."); if (!op.isForLike()) return rewriter.notifyMatchFailure(op, "Loop is not for-like"); int64_t maxTripCount; diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index 59884616f13f..af96e108efbd 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -152,6 +152,23 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch return %0 : !torch.vtensor } +// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region( +// CHECK: %[[LOOP:.*]] = torch.prim.Loop +func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list, %arg2: !torch.int) -> !torch.vtensor { + %true = torch.constant.bool true + %0 = torch.prim.Loop %arg2, %true, init(%arg0) { + ^bb0(%arg3: !torch.int, %arg4: !torch.vtensor): + %1 = torch.shape.calculate { + torch.shape.calculate.yield %arg4 : !torch.vtensor + } shapes { + torch.prim.Print(%arg3) : !torch.int + torch.shape.calculate.yield.shapes %arg1 : !torch.list + } : !torch.vtensor + torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor) + } : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor + return %0 : !torch.vtensor +} + // CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.int, From 2f9a68cc1e1af69cd6d339bc694e707e14758094 Mon Sep 17 00:00:00 2001 From: lingzhiz1998 Date: Wed, 23 Oct 2024 21:01:20 +0800 Subject: [PATCH 03/19] Add canonicalization pattern for maxpool3d with indices op (#3704) As discussed in https://github.com/llvm/torch-mlir/pull/3652, we should replace maxpool3dwithindices with maxpool3d if indices have no user. 
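For reference, a small eager-mode sketch of the equivalence this canonicalization relies on: when the indices result of `aten.max_pool3d_with_indices` has no users, `aten.max_pool3d` with the same parameters yields the same pooled values. This is not part of the patch; the tensor shape and pooling parameters below are arbitrary illustrative values.

```python
import torch
import torch.nn.functional as F

# Pooling with indices vs. plain pooling over the same input and parameters.
x = torch.randn(2, 3, 8, 8, 8)
out_with_idx, _indices = F.max_pool3d(
    x, kernel_size=3, stride=2, padding=1, return_indices=True
)
out = F.max_pool3d(x, kernel_size=3, stride=2, padding=1)

# The pooled values match; only the (unused) indices differ in availability.
assert torch.equal(out, out_with_idx)
```

The patch implements this by templating the existing 2d pattern (`SimplifyMaxPoolWithIndices`) over both the 2d and 3d ops.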
--- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 44 ++++++++++++++++--- .../build_tools/torch_ods_gen.py | 3 +- test/Dialect/Torch/canonicalize.mlir | 18 ++++++++ 4 files changed, 58 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c3e0141530e4..de87fb46b0c7 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7352,6 +7352,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", printDefaultTorchOp(printer, *this, 6, 2); } }]; + let hasCanonicalizer = 1; } def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 97fc5494621b..a583ccfa4cb7 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5188,18 +5188,38 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// AtenMaxPool2dWithIndicesOp +// AtenMaxPoolWithIndicesOp //===----------------------------------------------------------------------===// -void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) { +namespace { + +template struct MaxPoolWithoutIndices { + using type = OpTy; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool2dOp; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool3dOp; +}; + +} // namespace + +template +struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern { + SimplifyMaxPoolWithIndices(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult + matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override { if (!op.getResult1().use_empty()) { return rewriter.notifyMatchFailure( - op, "result1 of MaxPool2dWithIndices should be unused"); + op, "result1 of MaxPoolWithIndices should be unused"); } - Value result = rewriter.create( + Value result = rewriter.create::type>( op->getLoc(), op.getResult0().getType(), op.getSelf(), op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), op.getCeilMode()); @@ -5207,7 +5227,17 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( op.getResult0().replaceAllUsesWith(result); rewriter.eraseOp(op); return success(); - }); + } +}; + +void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); +} + +void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); } //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index e5dcc913527f..4038346d5ea9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -636,7 +636,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::max_pool3d : (Tensor, int[], int[], int[], 
int[], bool) -> (Tensor)") emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( - "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", + has_canonicalizer=True, ) emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index f13bf60cb15b..f63d313af575 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3136,6 +3136,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor // ----- +// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { +// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]] +// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32> +func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64> + return %result0 : !torch.vtensor<[10,64,56,56,56],f32> +} + +// ----- + // CHECK-LABEL: @torch.aten.clone$no_fold( func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) { // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor From d6feb2179c552c4b88bc3710d7a7e870eeea1734 Mon Sep 17 00:00:00 2001 From: Sriram Kumar <154416395+sriram-siloai@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:34:50 +0530 Subject: [PATCH 04/19] Added support for Maxpool (Autopad) (#3774) Added autopad. 
and passed 3 tests test_maxpool_2d_precomputed_same_upper test_maxpool_2d_same_lower' test_maxpool_2d_same_upper Address : https://github.com/nod-ai/SHARK-ModelDev/issues/843 2 attributes yet to complete : storage_order, indices output --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 32 +++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 80 +++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 168040d9b289..a7f707cae9bb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return rewriter.notifyMatchFailure(binder.op, "auto_pad bind failure"); - if (autoPad != "NOTSET") - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); Torch::ValueTensorType resultTypeOut; Value operand; @@ -1136,6 +1133,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + // set default padding if (padding.empty()) padding.resize(spatial, 0); if (strides.empty()) @@ -1143,6 +1141,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (dilations.empty()) dilations.resize(spatial, 1); + auto inputTensorType = cast(operand.getType()); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatial); + for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatial + dimIdx] = totalPad - padding[dimIdx]; + } + } + // If the padding is symmetric we can push the padding operation to the // torch operator. 
if (padding.size() == static_cast(2 * spatial)) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 21be2a65f4a6..d567db79fdf8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch return %0 : !torch.vtensor<[1,64,56,56],f32> } +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_lower +func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_upper +func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // 
CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper +func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64}{ + // CHECK: %[[int3:.*]] = torch.constant.int 3 + // CHECK: %[[int3_0:.*]] = torch.constant.int 3 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_3:.*]] = torch.constant.int 1 + // CHECK: %[[int1_4:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,3,3],f32> +%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> +return %0 : !torch.vtensor<[1,1,3,3],f32> +} + // ----- From 1259e8a00a86231ff608ab1d19cd1ad9806fcd2b Mon Sep 17 00:00:00 2001 From: zjgarvey 
<47986913+zjgarvey@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:09:00 -0500 Subject: [PATCH 05/19] Add Some Folders For Small Reshape Ops (#3813) ### Changes 1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`, and `aten.unflatten.int` 2. Folder for transpose 3. Extended support for the `aten.slice.Tensor` op folder to include negative strides. ### Motivation The biggest motivation for this patch is to fold the extremely convoluted ir that gets generated when exporting a pytorch model with an `aten.pad` op to ONNX, then re-importing and lowering back to torch. For example, the verbose output of the e2e test `PadModule_basic` with `-c onnx`: ```mlir module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %none = torch.constant.none %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> %5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor} : () -> !torch.vtensor<[],f32> %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> return %16 : !torch.vtensor<[?,?,?,?],f32> } } {-# dialect_resources: { builtin: { 
_: "0x080000000400000000000000", __1: "0x080000000000000000000000010000000000000002000000000000000300000000000000", __2: "0x080000000000000000000000", __3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000", __4: "0x080000000000000000000000", __5: "0x08000000FFFFFFFFFFFFFFFF", __6: "0x080000000100000000000080", __7: "0x08000000FFFFFFFFFFFFFFFF", __8: "0x08000000FFFFFFFFFFFFFFFF", __9: "0x080000000000C03F" } } #-} ``` Get's converted to the torch IR: ```mlir module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %float1.500000e00 = torch.constant.float 1.500000e+00 %int-9223372036854775807 = torch.constant.int -9223372036854775807 %int-1 = torch.constant.int -1 %int7 = torch.constant.int 7 %int6 = torch.constant.int 6 %int5 = torch.constant.int 5 %int3 = torch.constant.int 3 %int8 = torch.constant.int 8 %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int4 = torch.constant.int 4 %int0 = torch.constant.int 0 %0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64> %1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> %3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> %4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> %5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list %6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> %7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int %9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int %11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int %13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int %15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int %17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int %19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int %21 = 
torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int %23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?,?,?],f32> return %24 : !torch.vtensor<[?,?,?,?],f32> } } ``` ***All of these operations are useless***. It is literally the result of needing to reverse (and change the lexicographic order hierarchy of) padding ints provided via torch vs. ONNX pad ops, which is then subsequently UNDONE by our ONNX->Torch lowering (represented in the ordering of the generated list construct). With the added folders in this patch, the torch IR becomes: ``` module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %float1.500000e00 = torch.constant.float 1.500000e+00 %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?,?,?],f32> return %1 : !torch.vtensor<[?,?,?,?],f32> } } ``` --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 3 + lib/Dialect/Torch/IR/TorchOps.cpp | 123 +++++++++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 14 -- .../build_tools/torch_ods_gen.py | 8 +- test/Dialect/Torch/canonicalize.mlir | 76 +++++++++++ 5 files changed, 200 insertions(+), 24 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index de87fb46b0c7..36b2243afbba 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8080,6 +8080,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ @@ -9672,6 +9673,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ @@ -9696,6 +9698,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index a583ccfa4cb7..97b724984310 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -30,6 +30,24 @@ using namespace mlir::torch::Torch; // Utilities //===----------------------------------------------------------------------===// +OpFoldResult genericViewLikeFold(Attribute self, Type resultType) { + auto selfAttr = dyn_cast_or_null(self); + 
if (!selfAttr) + return nullptr; + + auto resultTy = dyn_cast_or_null(resultType); + if (!resultTy || !resultTy.areAllSizesKnown()) + return nullptr; + + if (selfAttr.isSplat()) { + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + selfAttr.getSplatValue()); + } + return DenseElementsAttr::get( + resultTy.toBuiltinTensor(), + llvm::to_vector(selfAttr.getValues())); +} + Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Location loc, Value value, Type desiredType, @@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { + if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType())) + return genericFold; auto inputType = dyn_cast(getOperand(0).getType()); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; @@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenFlattenUsingIntsOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + //===----------------------------------------------------------------------===// // AtenUnflattenIntOp //===----------------------------------------------------------------------===// +OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + void AtenUnflattenIntOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { // if there are only two sizes and one of them is statically 1, then convert @@ -3722,6 +3754,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } +//===----------------------------------------------------------------------===// +// AtenTransposeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) { + // first check for no-op + IntegerAttr dim0 = dyn_cast_or_null(adaptor.getDim0()); + IntegerAttr dim1 = dyn_cast_or_null(adaptor.getDim1()); + if (!dim0 || !dim1) + return nullptr; + int64_t _dim0 = dim0.getValue().getSExtValue(); + int64_t _dim1 = dim1.getValue().getSExtValue(); + auto selfTy = dyn_cast(getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return nullptr; + int64_t rank = selfTy.getSizes().size(); + _dim0 = toPositiveDim(_dim0, rank); + _dim1 = toPositiveDim(_dim1, rank); + if (!isValidDim(_dim0, rank) || !isValidDim(_dim1, rank)) + return nullptr; + // if dims are the same, return self + if (_dim0 == _dim1) + return getSelf(); + + // We set a maximum folding size of 16. This is a reasonable upper limit + // for shape computations. 
+ constexpr int64_t kMaxFoldSize = 16; + auto self = dyn_cast_or_null(adaptor.getSelf()); + if (!self || self.getNumElements() > kMaxFoldSize) + return nullptr; + auto resultTy = dyn_cast(getType()); + if (!selfTy || !resultTy || !selfTy.areAllSizesKnown()) + return nullptr; + if (self.isSplat()) + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + self.getSplatValue()); + + // TODO: add support for rank != 2 + if (rank != 2) + return nullptr; + + ArrayRef sizes = selfTy.getSizes(); + auto values = llvm::to_vector(self.getValues()); + // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0], + // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])]. + // e.g., Self size = [4,2]; Trans size = [2,4]. + // reindex(i) = (i % 4)*2 + (i // 4) . + // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 . + // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 . + // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 . + // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 . + // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 . + // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 . + auto reindex = [&](int64_t i) { + return (i % sizes[0]) * sizes[1] + (i / sizes[0]); + }; + SmallVector reordered; + for (int64_t i = 0; i < self.getNumElements(); i++) { + reordered.push_back(values[reindex(i)]); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered); +} + //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// @@ -3898,15 +3993,18 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { // Fold the slice if the output tensor is relatively small, currently // coded to 16: constexpr int64_t kMaxFold = 16; - if (input && start && step && dim && count <= kMaxFold) { + if (input && start && step && dim && end && count <= kMaxFold) { int64_t begin = start.getValue().getSExtValue(); int64_t limit = end.getValue().getSExtValue(); int64_t stride = step.getValue().getSExtValue(); - if (stride < 1) - return nullptr; begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin; limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit; + limit = limit < 0 ? -1 : limit; limit = std::min(limit, inType.getSizes()[dimInt]); + bool validIterArgs = + (stride > 0 && begin < limit) || (stride < 0 && begin > limit); + assert(validIterArgs && + "aten.slice.Tensor iteration args are statically invalid."); int64_t inputRank = inType.getSizes().size(); llvm::SmallVector inputStrides(inputRank, 1); @@ -3919,10 +4017,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) { if (currDim >= inputRank) return; - size_t _begin = (currDim == dimInt) ? begin : 0; - size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; - size_t _stride = (currDim == dimInt) ? stride : 1; - for (size_t i = _begin; i < _limit; i += _stride) { + int64_t _stride = (currDim == dimInt) ? stride : 1; + int64_t _begin = (currDim == dimInt) ? begin : 0; + int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; + // ensure that the limit is reached exactly (even with negative strides) + // E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11 + // = 10 + (10-0) % 3 . + // E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 + + // (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 . 
+ // Note: cpp uses true math remainder "n % d = least positive int, x, such + // that d divides (n - x)" + int64_t limit_rem = (_limit - _begin) % _stride; + limit_rem = + (_stride > 0 || limit_rem == 0) ? limit_rem : limit_rem - _stride; + _limit += limit_rem; + for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) { if (currDim == inputRank - 1) { values.push_back(input.getValues()[currOffset + i]); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index dce3dea1ee03..ab5c54b762a8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2677,20 +2677,6 @@ "MultinomialModule2D_basic", "MultinomialModule2D_F32", "PixelShuffleModuleStaticRank4Float32_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 4038346d5ea9..31984d727048 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -684,7 +684,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") - emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") + emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True) emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True) emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") @@ -769,9 +769,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) - emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") + emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)", has_folder=True) emit( - "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True + "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, ) emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index f63d313af575..90b4e103c4fb 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1682,6 +1682,82 @@ func.func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[? 
return %1 : !torch.tensor<[?],f32> } +// CHECK-LABEL: func.func @torch.aten.view$fold_splat( +// CHECK: %[[SPLAT:.*]] = torch.vtensor.literal(dense<2> : tensor<2x4x1xsi64>) : !torch.vtensor<[2,4,1],si64> +// CHECK: return %[[SPLAT]] : !torch.vtensor<[2,4,1],si64> +func.func @torch.aten.view$fold_splat() -> !torch.vtensor<[2,4,1],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<2> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int2, %int4, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,4,1],si64> + return %2 : !torch.vtensor<[2,4,1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.view$fold_literal( +// CHECK: %[[LITERAL:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [ +// CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7]]]> : tensor<1x4x2xsi64>) : !torch.vtensor<[1,4,2],si64> +// CHECK: return %[[LITERAL]] : !torch.vtensor<[1,4,2],si64> +func.func @torch.aten.view$fold_literal() -> !torch.vtensor<[1,4,2],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<[0,1,2,3,4,5,6,7]> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,4,2],si64> + return %2 : !torch.vtensor<[1,4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_literal( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xsi64>) : !torch.vtensor<[2,4],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,4],si64> +func.func @torch.aten.transpose.int$fold_literal() -> !torch.vtensor<[2,4],si64> { + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0,1],[2,3],[4,5],[6,7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.transpose.int %0, %int-1, %int0 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4], si64> + return %1 : !torch.vtensor<[2,4],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_noop( +// CHECK: return %arg0 : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.transpose.int$fold_noop(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int-1 = torch.constant.int -1 + %int3 = torch.constant.int 3 + %0 = torch.aten.transpose.int %arg0, %int-1, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$flip_slice_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [4, 5], [2, 3], [0, 1]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[4,2],si64> +func.func @torch.aten.slice.Tensor$flip_slice_fold() -> !torch.vtensor<[4,2],si64> { + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, 
!torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$negative_two_stride_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [2, 3]]> : tensor<2x2xsi64>) : !torch.vtensor<[2,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,2],si64> +func.func @torch.aten.slice.Tensor$negative_two_stride_fold() -> !torch.vtensor<[2,2],si64> { + %int-5 = torch.constant.int -5 + %int-1 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-5, %int-2 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64> + return %1 : !torch.vtensor<[2,2],si64> +} + // CHECK-LABEL: func.func @torch.aten.div.float$fold_zero_dividend( // CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00 // CHECK: return %[[CST0]] : !torch.float From 76209db5a5817e098cfced7f065a0f54e6b09d13 Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Thu, 24 Oct 2024 21:59:58 +0200 Subject: [PATCH 06/19] Update quantized matmul tests to DQ/Q format supported by fx_importer (#3815) Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the matmul tests. --- projects/pt1/e2e_testing/xfail_sets.py | 9 -- .../torch_mlir_e2e_test/test_suite/matmul.py | 130 ++++++++++-------- 2 files changed, 75 insertions(+), 64 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ab5c54b762a8..553a27924da0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -394,15 +394,6 @@ "AtenIntBoolOpModule_basic", "AtenIntMM_basic", "AtenItemFpOpModule_basic", - "AtenMatmulQMixedSigni8Transpose_basic", - "AtenMatmulQMixedSigni8_basic", - "AtenMatmulQint8MV_basic", - "AtenMatmulQint8_basic", - "AtenMatmulQint8VM_basic", - "AtenMatmulQint8VV_basic", - "AtenMmQMixedSigni8_basic", - "AtenMmQint8_basic", - "AtenMmQuint8_basic", "QuantizedReluInt32_basic", "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 40e6a735901d..17240cf953df 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils): # ============================================================================== +# For DQ-Q fake quantization ops +import torch.ao.quantization.fx._decomposed class AtenMmQint8(torch.nn.Module): @@ -352,12 +354,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQint8()) @@ -384,12 +388,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = 
torch._make_per_tensor_quantized_tensor(x, 0.199, 65) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.199, 65, 0, 255, torch.uint8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0215, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQuint8()) @@ -416,12 +422,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQMixedSigni8()) @@ -475,12 +483,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VM()) @@ -505,12 +515,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VV()) @@ -535,12 +547,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8MV()) @@ -565,12 +579,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + 
z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8()) @@ -597,12 +613,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8()) @@ -629,13 +647,15 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qy = torch.transpose(qy, 1, 2) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + y = torch.transpose(y, 1, 2) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose()) From ad9dfe974ee12c4a56c0047eaabfb9e7ad642b28 Mon Sep 17 00:00:00 2001 From: Dmitry Babokin Date: Fri, 25 Oct 2024 00:42:08 -0700 Subject: [PATCH 07/19] Fix clang warning about printf format (#3814) Compiling with clang 16.0 on macOS I have warnings about incorrect printf format (see below). Values to be printed are `int64_t`, but they are printed with `%zu` and `%ld`, which are not portable way to print this type. 
``` <...>/torch-mlir/test/CAPI/torch.c:52:3: warning: format specifies type 'size_t' (aka 'unsigned long') but the argument has type 'int64_t' (aka 'long long') [-Wformat] 52 | DEFINE_CHECK(NonValueTensor) | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~ <...>/torch-mlir/test/CAPI/torch.c:37:13: note: expanded from macro 'DEFINE_CHECK' 36 | fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ | ~~~ 37 | torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :78:1: note: expanded from here 78 | torchMlirTorchNonValueTensorTypeGetRank | ^ <...>/torch-mlir/test/CAPI/torch.c:52:3: warning: format specifies type 'long' but the argument has type 'int64_t' (aka 'long long') [-Wformat] 52 | DEFINE_CHECK(NonValueTensor) | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~ <...>/torch-mlir/test/CAPI/torch.c:42:15: note: expanded from macro 'DEFINE_CHECK' 41 | fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ | ~~~ 42 | TTT##Sizes[i]); \ | ^~~~~~~~~~~~~ :85:1: note: expanded from here 85 | NonValueTensorSizes | ^ <...>/torch-mlir/test/CAPI/torch.c:53:3: warning: format specifies type 'size_t' (aka 'unsigned long') but the argument has type 'int64_t' (aka 'long long') [-Wformat] 53 | DEFINE_CHECK(ValueTensor) | ^~~~~~~~~~~~~~~~~~~~~~~~~ <...>/torch-mlir/test/CAPI/torch.c:37:13: note: expanded from macro 'DEFINE_CHECK' 36 | fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ | ~~~ 37 | torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :112:1: note: expanded from here 112 | torchMlirTorchValueTensorTypeGetRank | ^ <...>/torch-mlir/test/CAPI/torch.c:53:3: warning: format specifies type 'long' but the argument has type 'int64_t' (aka 'long long') [-Wformat] 53 | DEFINE_CHECK(ValueTensor) | ^~~~~~~~~~~~~~~~~~~~~~~~~ <...>/torch-mlir/test/CAPI/torch.c:42:15: note: expanded from macro 'DEFINE_CHECK' 41 | fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ | ~~~ 42 | TTT##Sizes[i]); \ | ^~~~~~~~~~~~~ :119:1: note: expanded from here 119 | ValueTensorSizes | ^ 4 warnings generated. 
``` --- test/CAPI/torch.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/CAPI/torch.c b/test/CAPI/torch.c index d42cf96d554c..3d1308f08b25 100644 --- a/test/CAPI/torch.c +++ b/test/CAPI/torch.c @@ -33,12 +33,12 @@ static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes, bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \ fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \ if (TTT##hasSizes) { \ - fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ + fprintf(stderr, #TTT "Type %s rank: %" PRId64 "\n", testName, \ torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \ torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ for (int i = 0; i < numSizes; ++i) { \ - fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ + fprintf(stderr, #TTT "Type %s pos %d size: %" PRId64 "\n", testName, i, \ TTT##Sizes[i]); \ } \ } \ From 54d9e2401376e7eb2c6c219e3b3555f45f8b2635 Mon Sep 17 00:00:00 2001 From: Andrija Bosnjakovic Date: Fri, 25 Oct 2024 18:01:05 +0200 Subject: [PATCH 08/19] [TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 84 +++++++++ .../Transforms/AbstractInterpLibrary.cpp | 57 +++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 132 ++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 2 + projects/pt1/e2e_testing/xfail_sets.py | 22 +++ .../build_tools/abstract_interp_lib_gen.py | 24 +++ .../build_tools/torch_ods_gen.py | 4 + .../test_suite/backprop.py | 161 ++++++++++++++++++ .../test_suite/elementwise.py | 82 +++++++++ 9 files changed, 568 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 36b2243afbba..206d70ffbfa9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ }]; } +def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ AllowsTypeRefinement, HasValueSemantics, @@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + Torch_BoolType:$self_is_result + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f2963f7c803d..46cb3e6b7efe 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : 
(!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12247,6 +12263,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9b24d0e959f3..1fefb59a4cac 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3489,6 +3489,59 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace +namespace { +class DecomposeAtenRreluWithNoiseBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradOutput = op.getGradOutput(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, + "training should be a bool constant"); + } + + bool selfIsResult = false; + if (!matchPattern(op.getSelfIsResult(), + m_TorchConstantBool(&selfIsResult)) || + selfIsResult) + return rewriter.notifyMatchFailure( + op, "unimplemented: self_is_result should be false"); + + double lower, upper; + if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) || + !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) { + return rewriter.notifyMatchFailure( + op, "lower and upper should be float constants"); + } + + if (training && (upper - lower > 0.000001)) { + Value rreluWithNoiseBackwardOutput = + rewriter.create(loc, resType, gradOutput, noise); + rewriter.replaceOp(op, rreluWithNoiseBackwardOutput); + } else { + double negative_slope = (upper + lower) / 2; + Value cstNegativeSlope = rewriter.create( + loc, rewriter.getF64FloatAttr(negative_slope)); + rewriter.replaceOpWithNewOp( + op, resType, gradOutput, self, cstNegativeSlope, + op.getSelfIsResult()); + } + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenPreluOp : public OpRewritePattern { public: @@ -3588,6 +3641,82 @@ class DecomposeAtenRreluOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenRreluWithNoiseOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + + Value alpha; + if (training) { + Value none = rewriter.create(loc); + Value 
emptyTensor = rewriter.create( + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); + } + + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + + Value scaledSelf; + if (training) { + scaledSelf = rewriter.create(loc, resType, self, alpha); + auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(), + rewriter.getI1Type()); + Value oneTensor = + createRank0Tensor(rewriter, loc, resType, constantOneFloat); + Value not_positive = rewriter.create( + loc, boolResType, self, constantZeroFloat); + noise = rewriter.create(loc, resType, not_positive, + alpha, oneTensor); + } else { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } + + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, rreluOutput); + return success(); + } +}; +} // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { @@ -9924,6 +10053,9 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index ebc43faa595c..feb63db0b324 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 553a27924da0..e370a1d8b73d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1207,6 +1207,10 @@ "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "ElementwiseRemainderTensorModule_Float_basic", "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", @@ -2106,6 +2110,7 @@ "ElementwiseReciprocalModule_basic", "ElementwiseRelu6Module_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic", 
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", @@ -2238,6 +2243,10 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ResNet18StaticModule_basic", @@ -2436,6 +2445,10 @@ "ViewSizeFromOtherTensor_basic", "RenormModuleFloat32NegativeDim_basic", "RenormModuleFloat32_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", } ) - { ### Test failing in make_fx_tosa but not in tosa @@ -2854,6 +2867,10 @@ "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSgnModule_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", @@ -3002,6 +3019,11 @@ "ReduceL1NormComplexModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", + "RreluWithNoiseForwardBackwardModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d632e9815443..1cb9678ec5d5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]: return upstream_shape_functions.unary(grad_output) +def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]: + return upstream_shape_functions.unary(grad_output) + def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) 
-> List[int]: def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -3126,6 +3132,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES]) +def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3293,6 +3308,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()})) +def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + noise_rank, noise_dtype = noise_rank_dtype + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) + assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype) + assert self_rank == noise_rank + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 31984d727048..17f7faa10f22 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -302,6 +302,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", + "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor)", "aten::celu : (Tensor, Scalar) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", @@ -1171,6 +1172,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)" ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + emit( + "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" + ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index e209d15b2b0b..5e6e093902c4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -322,3 +322,164 @@ def forward(self, grad, input): @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule()) def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule()) +def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule()) +def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule()) +def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], 
torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule()) +def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseForwardBackwardModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + res = torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.4, + upper=0.6, + training=True, + self_is_result=False, + ) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule()) +def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils): + grad = tu.rand(256, 244) + input = tu.rand(256, 244, low=-1.0, high=1.0) + noise = tu.rand(256, 244) + torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True) + module.forward(grad, input, noise) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index ed5254353fd2..a62b901a91ec 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) +def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) +def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + res = 
torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule()) +def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule()) +def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__() From 2b01f8b7f3cca87c3dc9c75edd91397803e9f6d4 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 25 Oct 2024 18:37:19 -0400 Subject: [PATCH 09/19] [Tosa] : Add support for negative indices in index.tensor and index.Tensor_hacked_twin for TorchToTosa lowering. (#3790) 1. Negative indices for tensor indexing is handled by wrapping around the index values by checking their values at run time. Without the fix, there was a runtime error. 2. Added a lit test to lock down the behavior. 3. Updated the `xfails_set` for `fx_importer_tosa` config to lockdown the behavior with e2e test as well. "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY." 
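In scalar terms, the runtime wrap-around described in point 1 (implemented by the `wrapNegativeIndices` helper in the hunks below) reduces to the following; this is a sketch of the logic only, not the emitted TOSA op sequence:

```
#include <cstdint>

// A negative index i into a dimension of size dimSize refers to position
// dimSize + i, e.g. -1 selects the last element.
int64_t wrapNegativeIndex(int64_t i, int64_t dimSize) {
  return i < 0 ? i + dimSize : i;
}
```

In the lowering, the same comparison and select are built out of `tosa.add`, `tosa.greater`, and `tosa.select` ops so the check happens at run time rather than at compile time.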
--- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 83 +++++++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 4 +- test/Conversion/TorchToTosa/basic.mlir | 32 +++++++++ 3 files changed, 81 insertions(+), 38 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e5f4fea4f46c..b6dbdc2c7b8c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +Value wrapNegativeIndices(Value index, int maxIndex, Operation *op, + ConversionPatternRewriter &rewriter) { + + auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto maxIndexValue = + tosa::getConstTensor(rewriter, op, maxIndex, {}).value(); + + auto indexType = dyn_cast(index.getType()); + + auto wrappedIndicesOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), indexType, maxIndexValue, index); + auto boolType = indexType.clone(rewriter.getIntegerType(1)); + auto isNegativeIndices = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, index); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), + indexType, isNegativeIndices, + wrappedIndicesOp, index); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, @@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()); + Operation *indicesTf; + // Support for multiple indexes if (indexTensors.size() > 1) { // t[i, i] @@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( index); } + index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, + rewriter); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indiceShapeOneDim; for (auto shape : indexShape) { @@ -4299,49 +4322,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesShapeConcat = indexesShape[0]; uint64_t lastDim = indexesRank[0]; indicesShapeConcat.push_back(indicesTfConcatTensors.size()); - auto indicesTf = tosa::CreateOpAndInfer( + indicesTf = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. 
- auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + } else { - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); + // Single index + auto index = indexTensors[0]; + auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), + index); } - rewriter.replaceOp(op, {result.value()}); - return success(); - } + index = + wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter); - // Support for multiple index - auto index = indexTensors[0]; - auto indexType = dyn_cast(index.getType()); - auto indexShape = indexType.getShape(); - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } - - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indicesShape; - for (auto shape : indexShape) { - indicesShape.push_back(shape); + // Expand last dim of index to tf indices [2,3] -> [2,3,1] + SmallVector indicesShape; + for (auto shape : indexShape) { + indicesShape.push_back(shape); + } + indicesShape.push_back(1); + indicesTf = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, + rewriter.getDenseI64ArrayAttr(indicesShape)); } - indicesShape.push_back(1); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); if (!indicesTf) { return rewriter.notifyMatchFailure(op, @@ -4349,7 +4362,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // do the tf gathernp algorithm with tf style indices as input. 
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + indicesTf->getResult(0)); if (!result) { return rewriter.notifyMatchFailure( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e370a1d8b73d..82ca24443162 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1698,7 +1698,6 @@ "ArangeStartOutModule_basic", "ScatterSrcStaticModule_basic", # Runtime op verification: Out of bounds access - "IndexTensorNegativeIndexModule_basic", "ReduceAllDimEmpty_basic", } @@ -1706,7 +1705,6 @@ "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_scales_recompute_bilinear", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -2162,6 +2160,7 @@ "HardswishRandomModule_basic", "HardtanhBackward_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorNegativeIndexModule_basic", "IndexTensorStaticModule_basic", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", @@ -3635,7 +3634,6 @@ "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index e412bb390c35..ed6f909c4a1b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32> return %0 : !torch.vtensor<[2,3,4,4],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> +// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] 
{shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64> + +func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { + %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> + } From 9ab2a150f20abbddcb291b9437d5b2b3506c9ace Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 30 Oct 2024 20:18:24 +0800 Subject: [PATCH 10/19] [Torch] emit upsample_bilinear2d(.vec) ops (#3834) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 53 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 22 ++++++++ projects/pt1/e2e_testing/xfail_sets.py | 6 +++ .../build_tools/abstract_interp_lib_gen.py | 24 +++++++++ .../build_tools/torch_ods_gen.py | 4 ++ 5 files changed, 109 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 206d70ffbfa9..5ec6a4d1dcf9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14093,6 +14093,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [ }]; } +def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalFloatType:$scales_h, + AnyTorchOptionalFloatType:$scales_w + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 46cb3e6b7efe..1765786be0f6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -11043,6 +11043,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %10 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list, !torch.optional>, !torch.optional>) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12576,6 +12590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> 
!torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 82ca24443162..5686664d39ad 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -531,6 +531,9 @@ "_SoftmaxModule_basic", "UpSampleNearest2dDynamicFactor_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -979,6 +982,9 @@ # materialization callback produced value of incorrect type failed "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } STABLEHLO_PASS_SET = { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1cb9678ec5d5..d9e57d67421c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2342,6 +2342,20 @@ def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optio assert scale_factors is not None return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])] +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True) +]) +def aten〇upsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: + return [self[0], self[1], output_size[0], output_size[1]] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None), + Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]), + Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0]) +]) +def aten〇upsample_bilinear2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]: + return aten〇upsample_nearest2d〇vec〡shape(input, output_size, scale_factors) + # ============================================================================== # Dtype Functions # ============================================================================== @@ -3570,6 +3584,16 @@ def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], o self_rank, self_dtype = input_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True)) +def aten〇upsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None)) +def aten〇upsample_bilinear2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + 
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 17f7faa10f22..311636c820cc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1013,6 +1013,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") + emit( + "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)" + ) + emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)") emit( "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" ) From 16b3bd6e6c8fbf166aad51911ef3fb24e7c96858 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 30 Oct 2024 18:56:01 +0530 Subject: [PATCH 11/19] build: manually update PyTorch version and fix CI failure (#3830) This commit sets the PyTorch and TorchVision version to nightly release 2024-10-29. This commit also fixes the CI failure after this commit https://github.com/llvm/torch-mlir/commit/54d9e2401376e7eb2c6c219e3b3555f45f8b2635 got merged. The issue was that the CI checks in the PR were run before the previous roll pytorch update but the PR was actually merged after the roll pytorch update. Hence, the failure was not caught before merging the PR. While exporting the fx_graph through fx_importer for `rrelu` and `rrelu_with_noise` op for train mode, it decomposes the `aten.rrelu_with_noise` op based on the PyTorch decomposition which is the default behavior. However, the decomposition contains an input mutation specifically here https://github.com/pytorch/pytorch/blob/9bbe4a67ad137032add6a3b0b74bda66f5ef83d2/torch/_decomp/decompositions.py#L325, resulting in the runtime failure. This issue would probably be fixed by https://github.com/pytorch/pytorch/pull/138503. Until then, the failing tests are added to the xfail set. Also, after the roll pytorch update following tests started passing for fx_importer, and fx_importer_stablehlo config. - "ElementwiseRreluTrainModule_basic" - "ElementwiseRreluTrainStaticModule_basic" - "ElementwiseRreluWithNoiseTrainModule_basic" - "ElementwiseRreluWithNoiseTrainStaticModule_basic" This commit also updates the dtype check for the `aten.linear` op since the op now expects both the input tensors to have the same dtype. 
Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 18 ++++++++++-------- .../build_tools/abstract_interp_lib_gen.py | 2 +- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 5686664d39ad..3881aa145d1c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -420,7 +420,6 @@ "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", @@ -446,8 +445,6 @@ "NllLossModuleBackward1DSum_basic", "NllLossModuleBackward1DWeight_basic", "NllLossModuleBackward1D_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "PowIntFloatModule_basic", @@ -464,7 +461,6 @@ "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", "SortIntListReverse_basic", "SortIntList_basic", @@ -523,6 +519,11 @@ "MeshgridIndexingXY_basic", "Meshgrid_basic", "OneHotModule_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -690,7 +691,6 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", @@ -792,8 +792,6 @@ "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormalFunctionalModule_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "PowIntFloatModule_basic", @@ -829,7 +827,6 @@ "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScatterReduceFloatMaxModule", @@ -964,6 +961,11 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d9e57d67421c..36ab8fe2c69f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -5371,7 +5371,7 @@ def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype -@check_dtype_function(_check_two_tensor_op()) 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype diff --git a/pytorch-hash.txt b/pytorch-hash.txt index f9e0abfabac1..dd4f3a19ad33 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -160d421a40e934ac8183e47f9cbc8618a4bd97dd +c787213d413e85c66bdad0d8c9cde1c5ced34b1b diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index ca065711a140..960ca904e045 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20241020 +torch==2.6.0.dev20241029 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 608d687cb6d1..901fbd3d9a84 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.20.0.dev20241020 +torchvision==0.20.0.dev20241029 From 6b58c89914c737c40c4066249b8a0de37309f6bd Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:51:06 -0400 Subject: [PATCH 12/19] Remove variable used for only assertion (#3837) Removes a boolean variable that is used only for an assertion, and inlines the condition into the assertion. Signed-off-by: Max Dawkins --- lib/Dialect/Torch/IR/TorchOps.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 97b724984310..84fa405f94fd 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4001,10 +4001,9 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit; limit = limit < 0 ? -1 : limit; limit = std::min(limit, inType.getSizes()[dimInt]); - bool validIterArgs = - (stride > 0 && begin < limit) || (stride < 0 && begin > limit); - assert(validIterArgs && - "aten.slice.Tensor iteration args are statically invalid."); + assert((stride > 0 && begin < limit) || + (stride < 0 && begin > limit) && + "aten.slice.Tensor iteration args are statically invalid."); int64_t inputRank = inType.getSizes().size(); llvm::SmallVector inputStrides(inputRank, 1); From 8b0bf2e2930cc4ef0c9e1212b31c2c4fad2d9141 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:38:51 -0400 Subject: [PATCH 13/19] Bump LLVM to llvm/llvm-project@6c64c8a6f3f7 (#3818) - bumps llvm-project to https://github.com/llvm/llvm-project/commit/6c64c8a6f3f77c30745c751d4163ff6bf2fc323b - bumps stablehlo to https://github.com/openxla/stablehlo/commit/6e403b1aa6a71f5eaa09cc720e4ad42f692745e6 - Updates type conversion materialization functions to return Value after API change in llvm-project. 
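For reviewers, a condensed before/after sketch of the new materialization hook shape (a sketch against the updated mlir::TypeConverter API; the op name passed to createOrFold is abbreviated and may not match the torch-mlir sources verbatim):

    // Before the bump: callbacks returned std::optional<Value> and signalled
    // "not handled" with std::nullopt.
    // After the bump: callbacks return Value directly; a null Value means
    // "not handled" and lets other registered materializers run.
    typeConverter.addTargetMaterialization(
        [](OpBuilder &builder, IntegerType type, ValueRange inputs,
           Location loc) -> Value {
          if (!(type.getWidth() == 64 && type.isSignless()))
            return Value(); // was: return std::nullopt;
          assert(inputs.size() == 1);
          return builder.createOrFold<TorchConversion::ToI64Op>(loc, inputs[0]);
        });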
--------- Signed-off-by: Max Dawkins --- externals/llvm-project | 2 +- externals/stablehlo | 2 +- .../Transforms/BackendTypeConversion.cpp | 86 +++++++++---------- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index f0b3b6d15b2c..6c64c8a6f3f7 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit f0b3b6d15b2c0ee2cff2dd31dc075adb5d9a4ff7 +Subproject commit 6c64c8a6f3f77c30745c751d4163ff6bf2fc323b diff --git a/externals/stablehlo b/externals/stablehlo index d40285ef3db0..6e403b1aa6a7 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739 +Subproject commit 6e403b1aa6a71f5eaa09cc720e4ad42f692745e6 diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 0f2533e063f0..53de48f21934 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 1 && type.isSignless())) + return Value(); + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!isa(inputs[0].getType())) - return std::nullopt; - assert(inputs.size() == 1); - return builder.createOrFold(loc, inputs[0]); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. 
+ if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.createOrFold(loc, inputs[0]); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + Float64Type type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -137,19 +137,19 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, [](Torch::GeneratorType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!isa(inputs[0].getType())) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. 
+ if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); From 8ea73b7b5376e7f21260ad66db33ca4fa1241118 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 00:12:54 +0100 Subject: [PATCH 14/19] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 97e789cd8203..acf5290dbc6c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -4028,8 +4028,8 @@ FX_IMPORTER_TOSA_XFAIL_SET |= { "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseLogSigmoidModule_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "RsubInt0d_NumToTensor_Module_basic", From 79d99ffbda130b912c1360f4c139b1b55fbe34f6 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 08:18:06 +0100 Subject: [PATCH 15/19] Bump llvm --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index ca3473c82c82..41d02533ef16 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit ca3473c82c825f4a14238b3a9dec755a02338da4 +Subproject commit 41d02533ef16c5671972000ac69053f5305199bd From 6c213d7155035de9990a2c8aeb71721c7b308757 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 08:23:50 +0100 Subject: [PATCH 16/19] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index acf5290dbc6c..bd4fc5737496 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3992,6 +3992,12 @@ "EinsumStaticModule_basic", "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", "EinsumStaticWithEllipsisSlicingModule_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", From 8e6a9e078c828e5d9b275fb979c17e7d87677526 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 09:00:16 +0100 Subject: [PATCH 17/19] xfail --- projects/pt1/e2e_testing/xfail_sets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bd4fc5737496..2821a92549fb 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -522,6 +522,8 @@ "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", 
"SplitTensorListUnpackModule_basic", @@ -555,10 +557,6 @@ FX_IMPORTER_XFAIL_SET |= { "AtenSubFloatModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "EqIntModule_basic", "GeFloatModule_basic", "GtIntModule_basic", @@ -4034,6 +4032,8 @@ FX_IMPORTER_TOSA_XFAIL_SET |= { "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "NumToTensorFloatModule_basic", From d55d7b9288700ddc7ef2894537724b49e2e81153 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 29 Jan 2025 11:22:54 +0100 Subject: [PATCH 18/19] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 93e911df3d12..bac2e1966f55 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3682,7 +3682,6 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexSelectRank0IdxModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", @@ -4012,7 +4011,6 @@ "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", "GridSamplerBasic4_basic", - "IndexSelectRank0IdxModule_basic", "IouOfModule_basic", "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dStaticCeilModeTrueModule_basic", From 0038abc7ff0b800e7e656d471511ff62887c194a Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 29 Jan 2025 14:10:16 +0100 Subject: [PATCH 19/19] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bac2e1966f55..e9784a52fa85 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -506,10 +506,7 @@ "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - "OneHotModule_basic", # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "BernoulliFloatModule_basic", @@ -525,8 +522,6 @@ "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "SplitTensorGetItem_Module_basic", @@ -555,7 +550,6 @@ "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - "OneHotModule_basic", "UniformModule_basic", "UniformStaticShapeModule_basic", } @@ -3627,6 +3621,8 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", 
"ElementwiseSinIntModule_basic",