From f00339e86aa8c521aa2ebca20a067da5f2b17d67 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Tue, 25 Nov 2025 10:21:42 -0500
Subject: [PATCH] [tosa] : Add support for quantize_per_tensor.

---
 .../TorchToTosa/TosaLegalizeCommon.h          |  25 ++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    | 264 ++++++++----------
 .../TorchToTosa/TosaLegalizeCommon.cpp        | 153 ++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |   4 -
 test/Conversion/TorchToTosa/quantization.mlir |  34 +++
 5 files changed, 331 insertions(+), 149 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
index 4041e522fca1..a79aee4f077f 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -111,6 +111,31 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                           RankedTensorType output_type, Value input_value,
                           ElementsAttr axes_elems, bool keep_dims);
 
+// Creates IntegerAttrs for clamping, using provided min/max values or the
+// numeric limits of the element type if the values are not provided.
+LogicalResult getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
+                                   Operation *op, Type elemTy,
+                                   std::optional<int64_t> minInt,
+                                   std::optional<int64_t> maxInt,
+                                   IntegerAttr &minAttr, IntegerAttr &maxAttr);
+
+// Creates FloatAttrs for clamping, using provided min/max values or the
+// numeric limits of the element type if the values are not provided.
+LogicalResult getFloatClampAttrs(ConversionPatternRewriter &rewriter,
+                                 Operation *op, Type elemTy,
+                                 std::optional<double> minFloat,
+                                 std::optional<double> maxFloat,
+                                 FloatAttr &minAttr, FloatAttr &maxAttr);
+
+// Implements "round half to even" logic for aten.round using TOSA ops.
+// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
+//   res = floor(input)
+// else:
+//   res = ceil(input)
+std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
+                                           Operation *op, Value input,
+                                           RankedTensorType resultTy);
+
 } // namespace tosa
 } // namespace mlir
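
For reference, a minimal scalar model of the predicate that createRoundHalfToEven assembles out of TOSA ops; this sketch is illustrative and not part of the patch, and the name roundHalfToEven is hypothetical:

#include <cmath>

// Scalar sketch of the select condition built by createRoundHalfToEven.
double roundHalfToEven(double input) {
  double floorVal = std::floor(input);
  double frac = input - floorVal; // always in [0, 1)
  // Mirrors the mul-by-0.5 / floor / mul-by-2 / equal sequence in the lowering.
  bool floorIsEven = std::floor(floorVal / 2.0) * 2.0 == floorVal;
  if (frac < 0.5 || (frac == 0.5 && floorIsEven))
    return floorVal;       // round down
  return std::ceil(input); // round up
}

For example, roundHalfToEven(2.5) yields 2.0 while roundHalfToEven(3.5) yields 4.0, matching aten.round's tie-to-even semantics (the same behavior std::nearbyint gives under the default FE_TONEAREST rounding mode).
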
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 62fc212b23be..cf1365fb98d4 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5304,69 +5304,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
       dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
   auto outElemTy = outType.getElementType();
 
-  int64_t minInt, maxInt;
-  double minFloat, maxFloat;
-  bool isMinNotNone = false;
-  bool isMaxNotNone = false;
-
-  auto isMinInt = matchPattern(op.getMin(), m_TorchConstantInt(&minInt));
-  auto isMinFloat = matchPattern(op.getMin(), m_TorchConstantFloat(&minFloat));
-  if (isMinInt) {
-    minFloat = static_cast<double>(minInt);
-    isMinNotNone = true;
-  } else if (isMinFloat) {
-    minInt = static_cast<int64_t>(minFloat);
-    isMinNotNone = true;
-  } else {
-    if (succeeded(checkNotNone(rewriter, op, op.getMin())))
+  std::optional<int64_t> minInt;
+  std::optional<double> minFloat;
+  {
+    int64_t minIntVal;
+    double minFloatVal;
+    if (matchPattern(op.getMin(), m_TorchConstantInt(&minIntVal))) {
+      minInt = minIntVal;
+      minFloat = static_cast<double>(minIntVal);
+    } else if (matchPattern(op.getMin(), m_TorchConstantFloat(&minFloatVal))) {
+      minFloat = minFloatVal;
+      minInt = static_cast<int64_t>(minFloatVal);
+    } else if (succeeded(checkNotNone(rewriter, op, op.getMin()))) {
       return rewriter.notifyMatchFailure(op,
                                          "min attr should be a torch constant");
+    }
   }
 
-  auto isMaxInt = matchPattern(op.getMax(), m_TorchConstantInt(&maxInt));
-  auto isMaxFloat = matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloat));
-  if (isMaxInt) {
-    maxFloat = static_cast<double>(maxInt);
-    isMaxNotNone = true;
-  } else if (isMaxFloat) {
-    maxInt = static_cast<int64_t>(maxFloat);
-    isMaxNotNone = true;
-  } else {
-    if (succeeded(checkNotNone(rewriter, op, op.getMax())))
+  std::optional<int64_t> maxInt;
+  std::optional<double> maxFloat;
+  {
+    int64_t maxIntVal;
+    double maxFloatVal;
+    if (matchPattern(op.getMax(), m_TorchConstantInt(&maxIntVal))) {
+      maxInt = maxIntVal;
+      maxFloat = static_cast<double>(maxIntVal);
+    } else if (matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloatVal))) {
+      maxFloat = maxFloatVal;
+      maxInt = static_cast<int64_t>(maxFloatVal);
+    } else if (succeeded(checkNotNone(rewriter, op, op.getMax()))) {
       return rewriter.notifyMatchFailure(op,
                                          "max attr should be a torch constant");
+    }
   }
 
   if (!isa<mlir::FloatType>(outElemTy)) {
     IntegerAttr minIntAttr, maxIntAttr;
-    if (outElemTy.isInteger(8)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int8_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int8_t>::max());
-    } else if (outElemTy.isInteger(16)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int16_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int16_t>::max());
-    } else if (outElemTy.isInteger(32)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int32_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int32_t>::max());
-    } else if (outElemTy.isInteger(64)) {
-      minIntAttr = rewriter.getI64IntegerAttr(
-          isMinNotNone ? minInt : std::numeric_limits<int64_t>::min());
-      maxIntAttr = rewriter.getI64IntegerAttr(
-          isMaxNotNone ? maxInt : std::numeric_limits<int64_t>::max());
-    } else {
-      return rewriter.notifyMatchFailure(op, "Unsupported integer type");
+    if (failed(tosa::getIntegerClampAttrs(rewriter, op, outElemTy, minInt,
+                                          maxInt, minIntAttr, maxIntAttr))) {
+      return failure();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -5376,28 +5352,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
         tosa::NanPropagationMode::PROPAGATE));
   } else {
     FloatAttr minFloatAttr, maxFloatAttr;
-    if (outElemTy.isF16()) {
-      minFloatAttr =
-          rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
-      maxFloatAttr =
-          rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
-    } else if (outElemTy.isBF16()) {
-      minFloatAttr = rewriter.getFloatAttr(
-          rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
-      maxFloatAttr = rewriter.getFloatAttr(
-          rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
-    } else if (outElemTy.isF32()) {
-      minFloatAttr = rewriter.getF32FloatAttr(
-          isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
-      maxFloatAttr = rewriter.getF32FloatAttr(
-          isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
-    } else if (outElemTy.isF64()) {
-      minFloatAttr = rewriter.getF64FloatAttr(
-          isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
-      maxFloatAttr = rewriter.getF64FloatAttr(
-          isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Unsupported floating-point type");
+    if (failed(tosa::getFloatClampAttrs(rewriter, op, outElemTy, minFloat,
+                                        maxFloat, minFloatAttr,
+                                        maxFloatAttr))) {
+      return failure();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -7308,17 +7266,6 @@ template <>
 LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
     AtenRoundOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  // To round to the nearest integer, we will consider the fractional part of
-  // the input element (= input element - integer part of element). If the
-  // fractional part is smaller than 0.5, round the number down. If the
-  // fractional part is 0.5, apply "round half to even" rule. If the fractional
-  // part is greater than 0.5, round up.
-  //
-  // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
-  //   res = floor(input)
-  // else:
-  //   res = ceil(input)
-
   auto self = adaptor.getSelf();
 
   auto selfTy = dyn_cast<TensorType>(self.getType());
@@ -7328,67 +7275,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
   auto resultTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
-  auto boolTy =
-      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
-
-  auto resultElemTy = resultTy.getElementType();
-
-  auto oneHalf =
-      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
-
-  auto two =
-      tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
-
-  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf)
-          .failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed())
+  auto result = tosa::createRoundHalfToEven(rewriter, op, self, resultTy);
+  if (!result) {
     return rewriter.notifyMatchFailure(
-        op, "Failed to equalize ranks among operands and result");
-
-  auto floorInput =
-      tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, self);
-
-  // input - floor(input)
-  auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy,
-                                            self, floorInput.getResult());
-
-  auto ceilInput = tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, self);
-
-  auto floorInputDivByTwo = tosa::createMulOpAndCast(
-      rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
-
-  auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy,
-                                              floorInputDivByTwo.getResult());
-
-  // (floor(input) // 2) * 2
-  auto evenComparison = tosa::createMulOpAndCast(
-      rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0);
-
-  // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
-  auto floorInputEven =
-      tosa::EqualOp::create(rewriter, op->getLoc(), boolTy,
-                            floorInput.getResult(), evenComparison.getResult());
-
-  auto fracEqualOneHalf = tosa::EqualOp::create(
-      rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
-
-  auto fracLtOneHalf = tosa::GreaterOp::create(
-      rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
-
-  // (frac == 0.5) && (floor(input) % 2 == 0)
-  auto fracEqualOneHalfCond = tosa::LogicalAndOp::create(
-      rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
-      floorInputEven.getResult());
-
-  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
-  auto floorResultCond = tosa::LogicalOrOp::create(
-      rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(),
-      fracEqualOneHalfCond.getResult());
-
-  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
-      op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
-      ceilInput.getResult());
+        op, "failed to implement round-half-to-even with TOSA ops");
+  }
 
+  rewriter.replaceOp(op, *result);
   return success();
 }
 
@@ -9339,6 +9232,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.quantize_per_tensor
+// Implements
+// Q = clamp(round(X / scale) + zero_point)
+template <>
+LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
+    AtenQuantizePerTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto loc = op->getLoc();
+
+  // Get scale and zero_point as constants.
+  double scaleConst;
+  if (!matchPattern(op.getScale(), m_TorchConstantFloat(&scaleConst)))
+    return rewriter.notifyMatchFailure(op, "scale must be a Scalar constant");
+
+  int64_t zpConst;
+  if (!matchPattern(op.getZeroPoint(), m_TorchConstantInt(&zpConst)))
+    return rewriter.notifyMatchFailure(op,
+                                       "zero point must be a Scalar constant");
+
+  // Get input and result types.
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto inputElemTy = inputTy.getElementType();
+  auto resultTy = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(0).getType()));
+  auto resultElemTy = resultTy.getElementType();
+
+  // Rescale the input: input * (1.0 / scale)
+  auto scaleReciprocal = 1.0 / scaleConst;
+  auto scaleConstTensor = tosa::getConstTensor<float>(
+                              rewriter, op, scaleReciprocal, {}, inputElemTy)
+                              .value();
+  if (mlir::tosa::EqualizeRanks(rewriter, loc, input, scaleConstTensor)
+          .failed())
+    return rewriter.notifyMatchFailure(
+        op, "Failed to equalize ranks among operands");
+  Value rescaledInput = tosa::createMulOpAndCast(
+      rewriter, op, inputTy, input, scaleConstTensor, /*shift=*/0);
+
+  // Round (half to even, matching aten.round semantics).
+  auto rounded =
+      tosa::createRoundHalfToEven(rewriter, op, rescaledInput, inputTy);
+  if (!rounded) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to implement round-half-to-even with TOSA ops");
+  }
+
+  // Cast to the destination integer type.
+  auto intermediateIntTy = resultTy.clone(resultElemTy);
+  Value castToInt =
+      tosa::CastOp::create(rewriter, loc, intermediateIntTy, *rounded);
+
+  // Add the zero point.
+  Value zpTensor =
+      tosa::createZeroPointTensor(rewriter, loc, intermediateIntTy, zpConst)
+          .value();
+  if (mlir::tosa::EqualizeRanks(rewriter, loc, castToInt, zpTensor).failed())
+    return failure();
+  Value withZp = tosa::AddOp::create(rewriter, loc, intermediateIntTy,
+                                     castToInt, zpTensor);
+
+  // Clamp the result to the valid range of the quantized type.
+  std::optional<int64_t> minInt,
+      maxInt; // deliberately left unset so we clamp to the numeric limits of
+              // the type
+  IntegerAttr minIntAttr, maxIntAttr;
+  if (failed(tosa::getIntegerClampAttrs(rewriter, op, resultElemTy, minInt,
+                                        maxInt, minIntAttr, maxIntAttr))) {
+    return failure();
+  }
+  Value clamped = tosa::ClampOp::create(
+      rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
+      /*nan_mode=*/
+      tosa::NanPropagationModeAttr::get(rewriter.getContext(),
                                        tosa::NanPropagationMode::PROPAGATE));
+
+  rewriter.replaceOp(op, clamped);
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -9713,6 +9686,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenTanOp);
   INSERT_ATENOP_PATTERN(AtenUnfoldOp);
   INSERT_ATENOP_PATTERN(AtenDequantizeTensorOp);
+  INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 036f0f2e5110..d94f1b48a804 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -1112,5 +1112,158 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+LogicalResult getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
+                                   Operation *op, Type elemTy,
+                                   std::optional<int64_t> minInt,
+                                   std::optional<int64_t> maxInt,
+                                   IntegerAttr &minAttr, IntegerAttr &maxAttr) {
+
+  if (!elemTy.isInteger()) {
+    return rewriter.notifyMatchFailure(
+        op, "getIntegerClampAttrs expects integer type");
+  }
+
+  int64_t finalMin, finalMax;
+  unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
+
+  switch (bitwidth) {
+  case 8:
+    finalMin = minInt.value_or(std::numeric_limits<int8_t>::min());
+    finalMax = maxInt.value_or(std::numeric_limits<int8_t>::max());
+    break;
+  case 16:
+    finalMin = minInt.value_or(std::numeric_limits<int16_t>::min());
+    finalMax = maxInt.value_or(std::numeric_limits<int16_t>::max());
+    break;
+  case 32:
+    finalMin = minInt.value_or(std::numeric_limits<int32_t>::min());
+    finalMax = maxInt.value_or(std::numeric_limits<int32_t>::max());
+    break;
+  case 64:
+    finalMin = minInt.value_or(std::numeric_limits<int64_t>::min());
+    finalMax = maxInt.value_or(std::numeric_limits<int64_t>::max());
+    break;
+  default:
+    return rewriter.notifyMatchFailure(
+        op, "Unsupported integer bitwidth for clamp");
+  }
+
+  minAttr = rewriter.getIntegerAttr(elemTy, finalMin);
+  maxAttr = rewriter.getIntegerAttr(elemTy, finalMax);
+
+  return success();
+}
+
+LogicalResult getFloatClampAttrs(ConversionPatternRewriter &rewriter,
+                                 Operation *op, Type elemTy,
+                                 std::optional<double> minFloat,
+                                 std::optional<double> maxFloat,
+                                 FloatAttr &minAttr, FloatAttr &maxAttr) {
+
+  if (elemTy.isF16()) {
+    minAttr = rewriter.getF16FloatAttr(
+        minFloat.value_or(torch::Torch::Float16Lowest));
+    maxAttr =
+        rewriter.getF16FloatAttr(maxFloat.value_or(torch::Torch::Float16Max));
+  } else if (elemTy.isBF16()) {
+    minAttr = rewriter.getFloatAttr(
+        elemTy, minFloat.value_or(torch::Torch::BFloat16Lowest));
+    maxAttr = rewriter.getFloatAttr(
+        elemTy, maxFloat.value_or(torch::Torch::BFloat16Max));
+  } else if (elemTy.isF32()) {
+    minAttr = rewriter.getF32FloatAttr(
+        minFloat.value_or(std::numeric_limits<float>::lowest()));
+    maxAttr = rewriter.getF32FloatAttr(
+        maxFloat.value_or(std::numeric_limits<float>::max()));
+  } else if (elemTy.isF64()) {
+    minAttr = rewriter.getF64FloatAttr(
+        minFloat.value_or(std::numeric_limits<double>::lowest()));
+    maxAttr = rewriter.getF64FloatAttr(
+        maxFloat.value_or(std::numeric_limits<double>::max()));
+  } else {
+    return rewriter.notifyMatchFailure(
+        op, "Unsupported floating-point type for clamp");
+  }
+
+  return success();
+}
+
+std::optional<Value>
+createRoundHalfToEven(ConversionPatternRewriter &rewriter, Operation *op,
+                      Value input, RankedTensorType resultTy) {
+  // To round to the nearest integer, we consider the fractional part of the
+  // input element (= input element - integer part of the element). If the
+  // fractional part is smaller than 0.5, round the number down. If the
+  // fractional part is 0.5, apply the "round half to even" rule. If the
+  // fractional part is greater than 0.5, round up.
+  //
+  // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
+  //   res = floor(input)
+  // else:
+  //   res = ceil(input)
+  auto boolTy =
+      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
+  auto resultElemTy = resultTy.getElementType();
+
+  auto oneHalf =
+      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
+  auto two =
+      tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
+
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, oneHalf)
+          .failed() ||
+      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, two).failed()) {
+    op->emitError("Failed to equalize ranks among operands and result");
+    return std::nullopt;
+  }
+
+  auto floorInput =
+      tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, input);
+
+  // input - floor(input)
+  auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy,
+                                            input, floorInput.getResult());
+
+  auto ceilInput =
+      tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, input);
+
+  auto floorInputDivByTwo = tosa::createMulOpAndCast(
+      rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
+
+  auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy,
+                                              floorInputDivByTwo.getResult());
+
+  // (floor(input) // 2) * 2
+  auto evenComparison = tosa::createMulOpAndCast(
+      rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0);
+
+  // (floor(input) // 2) * 2 == floor(input) <=> floor(input) % 2 == 0
+  auto floorInputEven =
+      tosa::EqualOp::create(rewriter, op->getLoc(), boolTy,
+                            floorInput.getResult(), evenComparison.getResult());
+
+  auto fracEqualOneHalf = tosa::EqualOp::create(
+      rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
+
+  auto fracLtOneHalf = tosa::GreaterOp::create(
+      rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
+
+  // (frac == 0.5) && (floor(input) % 2 == 0)
+  auto fracEqualOneHalfCond = tosa::LogicalAndOp::create(
+      rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
+      floorInputEven.getResult());
+
+  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
+  auto floorResultCond = tosa::LogicalOrOp::create(
+      rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(),
+      fracEqualOneHalfCond.getResult());
+
+  auto selectOp = tosa::SelectOp::create(
+      rewriter, op->getLoc(), resultTy, floorResultCond.getResult(),
+      floorInput.getResult(), ceilInput.getResult());
+
+  return selectOp.getResult();
+}
+
 } // namespace tosa
 } // namespace mlir
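
The defaulting in these helpers reduces to std::optional::value_or over the type's numeric limits. A minimal sketch for the int8 case, with plain integers standing in for IntegerAttrs and an illustrative function name:

#include <cstdint>
#include <limits>
#include <optional>
#include <utility>

// Bounds chosen by getIntegerClampAttrs for an i8 element type: explicitly
// provided values win, otherwise the numeric limits of the type are used.
std::pair<int64_t, int64_t> clampBoundsI8(std::optional<int64_t> minInt,
                                          std::optional<int64_t> maxInt) {
  return {minInt.value_or(std::numeric_limits<int8_t>::min()),  // -128
          maxInt.value_or(std::numeric_limits<int8_t>::max())}; // 127
}

The quantize_per_tensor legalization passes empty optionals, so clampBoundsI8({}, {}) models its behavior: it always clamps to the full range of the target type.
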
"Conv2dQInt8Module_basic", - "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", - "Conv2dQInt8Module_not_depthwise", "Conv2dQInt8PerChannelModule_basic", "Conv2dQInt8PerChannelModule_depthwise", "Conv2dQInt8PerChannelModule_grouped", @@ -3692,7 +3689,6 @@ "Conv3dWithSamePaddingModule_basic", "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", - "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", diff --git a/test/Conversion/TorchToTosa/quantization.mlir b/test/Conversion/TorchToTosa/quantization.mlir index 74eef9a496d1..531b62fb9985 100644 --- a/test/Conversion/TorchToTosa/quantization.mlir +++ b/test/Conversion/TorchToTosa/quantization.mlir @@ -42,3 +42,37 @@ func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[ %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[3,3],!torch.qint32> -> !torch.vtensor<[3,3],f32> return %7 : !torch.vtensor<[3,3],f32> } + +// ----- +// CHECK-LABEL: func.func @quantization_per_tensor( +// CHECK-SAME: %[[IN:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> { +// CHECK: %[[ZP:.*]] = "tosa.const"() <{values = dense<3> : tensor<1x1x1xi8>}> : () -> tensor<1x1x1xi8> +// CHECK: %[[C2:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: %[[CHALF:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: %[[C10:.*]] = "tosa.const"() <{values = dense<1.000000e+01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[IN]] : !torch.vtensor<[2,4,4],f32> -> tensor<2x4x4xf32> +// CHECK: %[[RESCALE:.*]] = tosa.mul %[[IN_TENSOR]], %[[C10]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32> +// CHECK: %[[FLOOR:.*]] = tosa.floor %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32> +// CHECK: %[[FRAC:.*]] = tosa.sub %[[RESCALE]], %[[FLOOR]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32> +// CHECK: %[[CEIL:.*]] = tosa.ceil %[[RESCALE]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32> +// CHECK: %[[FLOOR_DIV_BY_2:.*]] = tosa.mul %[[FLOOR]], %[[CHALF]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32> +// CHECK: %[[FLOOR_DIV:.*]] = tosa.floor %[[FLOOR_DIV_BY_2]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xf32> +// CHECK: %[[EVEN_COMP:.*]] = tosa.mul %[[FLOOR_DIV]], %[[C2]], %[[MUL_SHIFT]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<2x4x4xf32> +// CHECK: %[[FLOOR_INPUT_EVEN:.*]] = tosa.equal %[[FLOOR]], %[[EVEN_COMP]] : (tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1> +// CHECK: %[[FRAC_EQ_HALF:.*]] = tosa.equal %[[FRAC]], %[[CHALF]] : (tensor<2x4x4xf32>, tensor<1x1x1xf32>) -> tensor<2x4x4xi1> +// CHECK: %[[GRTR:.*]] = tosa.greater %[[CHALF]], %[[FRAC]] : (tensor<1x1x1xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xi1> +// CHECK: %[[AND:.*]] = tosa.logical_and %[[FRAC_EQ_HALF]], %[[FLOOR_INPUT_EVEN]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1> +// CHECK: %[[OR:.*]] = tosa.logical_or %[[GRTR]], %[[AND]] : (tensor<2x4x4xi1>, tensor<2x4x4xi1>) -> tensor<2x4x4xi1> +// CHECK: %[[SELECT:.*]] = tosa.select %[[OR]], %[[FLOOR]], %[[CEIL]] : (tensor<2x4x4xi1>, tensor<2x4x4xf32>, tensor<2x4x4xf32>) -> tensor<2x4x4xf32> +// CHECK: 
%[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<2x4x4xf32>) -> tensor<2x4x4xi8> +// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[ZP]] : (tensor<2x4x4xi8>, tensor<1x1x1xi8>) -> tensor<2x4x4xi8> +// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x4x4xi8> -> !torch.vtensor<[2,4,4],!torch.qint8> +// CHECK: return %[[RES]] +func.func @quantization_per_tensor(%arg0: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[2,4,4],!torch.qint8> { + %dtype = torch.constant.int 12 + %scale = torch.constant.float 0.1 + %zp = torch.constant.int 3 + %0 = torch.aten.quantize_per_tensor %arg0, %scale, %zp, %dtype : !torch.vtensor<[2,4,4],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],!torch.qint8> + return %0 : !torch.vtensor<[2,4,4],!torch.qint8> +}
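
For context, the PyTorch-level op this test exercises can be reproduced with the LibTorch C++ API; the scale, zero point, and dtype below mirror the FileCheck test above (a usage sketch, not part of the patch):

#include <iostream>
#include <torch/torch.h>

int main() {
  torch::Tensor x = torch::rand({2, 4, 4});
  // scale = 0.1, zero_point = 3, dtype = qint8 (torch dtype code 12).
  torch::Tensor q = torch::quantize_per_tensor(x, /*scale=*/0.1,
                                               /*zero_point=*/3, torch::kQInt8);
  // int_repr() exposes the underlying i8 values the TOSA lowering produces.
  std::cout << q.int_repr() << "\n";
  return 0;
}
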