From f8158540e93392b223db2b3e0df261f011f60dd0 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Fri, 12 Sep 2025 20:24:55 +0000 Subject: [PATCH] [TOSA] Update TOSA's rounding mode, nan propagation and resize mode to enums Context: TOSA's rounding mode, nan propagation, and resize mode have been updated from string attributes to enums. This commit updates the Torch to TOSA path to align with those changes. Signed-off-by: Justin Ngo --- .../TorchToTosa/TosaLegalizeUtils.h | 5 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 69 +++++++++++++------ .../TorchToTosa/TosaLegalizeCommon.cpp | 7 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 17 +++-- test/Conversion/TorchToTosa/basic.mlir | 4 +- 5 files changed, 69 insertions(+), 33 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index be1ea0c3221a..14df4928681a 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -11,6 +11,7 @@ #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -26,8 +27,8 @@ namespace tosa { // rounding mode Value buildRescale(PatternRewriter &rewriter, Operation *op, ShapedType output_type, Value input_val, double scale, - int64_t input_zp, int64_t output_zp, StringRef rounding_mode, - bool scale32); + int64_t input_zp, int64_t output_zp, + tosa::RoundingMode rounding_mode, bool scale32); // Creates TOSA rescale op with int32 output Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0bc93f711ad6..f4b487dc3115 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -138,7 +138,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern { // tosa.minimum binaryOp = rewriter.create( op->getLoc(), outTy, lhs, rhs, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { binaryOp = tosa::createBinaryOpAndCast(rewriter, op, outTy, lhs, rhs); @@ -907,7 +909,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp rewriter.replaceOpWithNewOp( op, outTy, self, minFloatAttr, maxFloatAttr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get(rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)); return success(); } @@ -1237,7 +1241,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(outputReduceTy), input, reduceDimAttr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) .getResult(); }; @@ -3925,7 +3931,9 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), selfElemType), - self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + self, dimAttr, /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { reduceOp = rewriter.create( op->getLoc(), @@ -3946,14 +3954,18 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), - negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + negateOp, dimAttr, /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax argMaxOp = rewriter.create( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), - self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + self, dimAttr, /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } if (argMaxOp.getType() != indicesType) { @@ -5202,7 +5214,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get(rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)); } else { FloatAttr minFloatAttr, maxFloatAttr; if (outElemTy.isF16()) { @@ -5231,7 +5245,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get(rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)); } return success(); @@ -5340,13 +5356,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum auto minThresholdCheck = rewriter.create( op->getLoc(), resultType, self, min, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get(rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)); // yi = min(max(xi, min_valuei), max_valuei) // Use default NaN Propagation mode "PROPAGATE" for tosa.minimum auto result = rewriter.create( op->getLoc(), resultType, minThresholdCheck, max, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get(rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)); rewriter.replaceOp(op, result); return success(); @@ -5934,7 +5954,10 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { pooledOutput = rewriter .create( op->getLoc(), outputTy, input, kernel, stride, pad, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), + tosa::NanPropagationMode::PROPAGATE)) .getResult(); } else if constexpr (std::is_same::value) { TypeAttr accType; @@ -6825,11 +6848,11 @@ ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only nearest and bilinear interpolation modes supported"); - std::string mode; + tosa::ResizeMode mode; if (pyMode == "bilinear") { - mode = "BILINEAR"; + mode = tosa::ResizeMode::BILINEAR; } else { - mode = "NEAREST_NEIGHBOR"; + mode = tosa::ResizeMode::NEAREST_NEIGHBOR; } bool alignCorners; @@ -6891,7 +6914,7 @@ ConvertAtenOp::matchAndRewrite( offset = 0; // If nearest neighbours we need to guarantee we round up. - if (mode == "NEAREST_NEIGHBOR" && alignCorners) { + if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) { offset += n / 2; } @@ -6911,7 +6934,8 @@ ConvertAtenOp::matchAndRewrite( tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x}); auto border = tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x}); - StringAttr modeAttr = rewriter.getStringAttr(mode); + + auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode); auto resizeOpResult = rewriter @@ -8605,11 +8629,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Clamp input to [eps, 1 - eps] when eps is not None // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp if (!isEpsNone) { - zi = rewriter - .create( - op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) - .getResult(); + zi = + rewriter + .create( + op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr, + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) + .getResult(); } auto one = diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 02d1390ed148..036f0f2e5110 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -764,7 +765,9 @@ std::optional convertReduceOpCommon( // and tosa.reduce_max reduce_op = CreateOpAndInfer( rewriter, op->getLoc(), reduce_type, val, axis_attr, - /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val, axis_attr); @@ -777,7 +780,7 @@ std::optional convertReduceOpCommon( RankedTensorType output_rescale_type = RankedTensorType::get(shape_vec, output_type.getElementType()); val = buildRescale(rewriter, op, output_rescale_type, val, output_scale, - 0, output_zp, "SINGLE_ROUND", true); + 0, output_zp, tosa::RoundingMode::SINGLE_ROUND, true); } // Optionally squeeze out the reduced axes. diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index bd902d8e2575..3fc11f4fa13f 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -35,8 +35,8 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter, // rounding mode Value buildRescale(PatternRewriter &rewriter, Operation *op, ShapedType output_type, Value input_val, double scale, - int64_t input_zp, int64_t output_zp, StringRef rounding_mode, - bool scale32) { + int64_t input_zp, int64_t output_zp, + tosa::RoundingMode rounding_mode, bool scale32) { int32_t multiplier; int32_t shift; @@ -70,7 +70,8 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val, input_zp_val.value(), output_zp_val.value(), - rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode), + rewriter.getBoolAttr(scale32), + tosa::RoundingModeAttr::get(rewriter.getContext(), rounding_mode), rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned), rewriter.getBoolAttr(output_unsigned)); @@ -87,7 +88,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, auto output_type = input_type.clone(rewriter.getI32Type()); return buildRescale(rewriter, op, output_type, input_val, input_scale, - input_zp, 0, "SINGLE_ROUND", true); + input_zp, 0, tosa::RoundingMode::SINGLE_ROUND, true); } // Creates a TOSA rescale op based on conv2d parameters. @@ -146,7 +147,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, multiplier_val, shift_val, input_zp_val.value(), output_zp_val.value(), - rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"), + rewriter.getBoolAttr(scale32), + tosa::RoundingModeAttr::get(rewriter.getContext(), + tosa::RoundingMode::DOUBLE_ROUND), rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned), rewriter.getBoolAttr(output_unsigned)); @@ -188,7 +191,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, multiplier_val, shift_val, input_zp_val.value(), output_zp_val.value(), - rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"), + rewriter.getBoolAttr(scale32), + tosa::RoundingModeAttr::get(rewriter.getContext(), + tosa::RoundingMode::DOUBLE_ROUND), rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned), rewriter.getBoolAttr(output_unsigned)); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index cb1a69e6a622..b0ee63852478 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1559,7 +1559,7 @@ func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !to // CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "BILINEAR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = BILINEAR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32> // CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> // CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32> @@ -1588,7 +1588,7 @@ func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch. // CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "NEAREST_NEIGHBOR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = NEAREST_NEIGHBOR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32> // CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> // CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32>