diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a5336ed6bf2cd..00df14b1bdb77 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern { } }; +// Collapse tensor<1xiN> into tensor +// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor +static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input, + Location loc) { + SmallVector reassociation; + // Create the collapsed type + auto inputType = cast(input.getType()); + auto elemType = inputType.getElementType(); + auto collapsedType = RankedTensorType::get({}, elemType); + // Emit the collapse op + return rewriter.create(loc, collapsedType, input, + reassociation); +} + +static llvm::SmallVector +convertToI8(const llvm::SmallVector &input) { + llvm::SmallVector output; + output.reserve(input.size()); + + for (auto v : llvm::map_range( + input, [](int32_t val) { return static_cast(val); })) { + output.push_back(v); + } + return output; +} + +// The shift or multiplier may be either constant or non-constant, depending on +// whether dynamic extension is enabled. +// - If the shift or multiplier is non-constant, add it as an input to +// linalg::GenericOp by: +// 1. Pushing it into 'genericInputs'. +// 2. Appending a corresponding affine map to 'indexingMaps'. +// - If the shift or multiplier is constant, set 'constant' instead. +static void setupLinalgGenericOpInputAndIndexingMap( + PatternRewriter &rewriter, llvm::SmallVector &values, + SmallVector &genericInputs, SmallVector &indexingMaps, + bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg, + bool isShift = false) { + + auto loc = op.getLoc(); + auto inputTy = cast(op.getInput().getType()); + unsigned rank = inputTy.getRank(); + SmallVector exprs = {rewriter.getAffineDimExpr(rank - 1)}; + + if (isConstant) { + // If we are rescaling per-channel then we need to store the + // values in a buffer. + if (values.size() == 1) { + IntegerAttr intAttr = isShift + ? rewriter.getI8IntegerAttr(values.front()) + : rewriter.getI32IntegerAttr(values.front()); + constant = rewriter.create(loc, intAttr); + } else { + auto elementType = + isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type(); + auto tensorType = RankedTensorType::get( + {static_cast(values.size())}, elementType); + DenseIntElementsAttr EltAttr; + if (isShift) + EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values)); + else + EltAttr = DenseIntElementsAttr::get(tensorType, values); + genericInputs.push_back( + arith::ConstantOp::create(rewriter, loc, EltAttr)); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } else { + // If we are not rescaling per-channel then we need to collapse 1xN to N + // and push broadcastMap. + auto operand = isShift ? op.getShift() : op.getMultiplier(); + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && tensorType.hasStaticShape() && + tensorType.getShape()[0] == 1) { + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = + AffineMap::get(rank, 0, {}, rewriter.getContext()); + genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc)); + indexingMaps.push_back(broadcastMap); + } else { + genericInputs.push_back(operand); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } + arg = indexingMaps.size() - 1; +} + +// Return the extended Zp to be used in subsequent arithmetic operations. +static Value getExtendZp(OpBuilder &builder, Type valueTy, + FailureOr maybeZp, Location loc, + ValueRange blockArgs, int64_t zpArg, + bool isOutputZp = false) { + Value result; + const int32_t bitwidth = valueTy.getIntOrFloatBitWidth(); + const uint32_t attrBitwidth = + isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32); + auto extendType = builder.getIntegerType(attrBitwidth); + // The Zp value can be either constant or non-constant, depending on + // whether dynamic extension is enabled. + // If 'maybeZp' fails, it indicates that Zp is non-constant and will + // be passed as an input to linalg::GenericOp. + if (failed(maybeZp)) { + result = blockArgs[zpArg]; + auto zpTy = result.getType(); + if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) { + // For ExtUIOp, the input must be signless. + // UnrealizedConversionCastOp will cast the input to signless type. + if (zpTy.isUnsignedInteger()) { + result = + UnrealizedConversionCastOp::create( + builder, loc, + builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result) + .getResult(0); + } + if (zpTy.isUnsignedInteger()) { + return builder.create(loc, extendType, result); + } else { + return builder.create(loc, extendType, result); + } + } + } else { + return builder.create( + loc, IntegerAttr::get(extendType, *maybeZp)); + } + return result; +} + class RescaleConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern { } } - // The shift and multiplier values. DenseElementsAttr shiftElems; - if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant shift input values"); + bool isShiftConstant = false; + if (matchPattern(op.getShift(), m_Constant(&shiftElems))) + isShiftConstant = true; DenseElementsAttr multiplierElems; - if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant multiplier input values"); - - llvm::SmallVector shiftValues = - llvm::to_vector(shiftElems.getValues()); - // explicit cast is required here - llvm::SmallVector multiplierValues = llvm::to_vector( - llvm::map_range(multiplierElems.getValues(), - [](IntegerAttr attr) -> int32_t { - return static_cast(attr.getInt()); - })); - - // If we shift by more than the bitwidth, this just sets to 0. - for (int i = 0, s = multiplierValues.size(); i < s; i++) { - if (shiftValues[i] > 63) { - shiftValues[i] = 0; - multiplierValues[i] = 0; + bool isMultiplierConstant = false; + if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) + isMultiplierConstant = true; + + llvm::SmallVector shiftValues; + llvm::SmallVector multiplierValues; + bool doubleRound; + + if (isMultiplierConstant && isShiftConstant) { + // explicit cast is required here + shiftValues = llvm::to_vector(llvm::map_range( + shiftElems.getValues(), [](IntegerAttr attr) -> int32_t { + return static_cast(attr.getInt()); + })); + multiplierValues = llvm::to_vector( + llvm::map_range(multiplierElems.getValues(), + [](IntegerAttr attr) -> int32_t { + return static_cast(attr.getInt()); + })); + + // If we shift by more than the bitwidth, this just sets to 0. + for (int i = 0, s = multiplierValues.size(); i < s; i++) { + if (shiftValues[i] > 63) { + shiftValues[i] = 0; + multiplierValues[i] = 0; + } } - } + // Double round only occurs if shift is greater than 31, check that this + // is ever true. + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && + llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); + } else + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND; - // Double round only occurs if shift is greater than 31, check that this - // is ever true. - - bool doubleRound = - op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && - llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); RoundingMode roundingMode = doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND; @@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern { // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; - if (multiplierValues.size() == 1) { - multiplierConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); - } else { - SmallVector multiplierExprs{ - rewriter.getAffineDimExpr(rank - 1)}; - auto multiplierType = - RankedTensorType::get({static_cast(multiplierValues.size())}, - rewriter.getI32Type()); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, - DenseIntElementsAttr::get(multiplierType, multiplierValues))); - - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, multiplierExprs, - rewriter.getContext())); - - multiplierArg = indexingMaps.size() - 1; - } + setupLinalgGenericOpInputAndIndexingMap( + rewriter, multiplierValues, genericInputs, indexingMaps, + isMultiplierConstant, op, multiplierConstant, multiplierArg); // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; - if (shiftValues.size() == 1) { - shiftConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); - } else { - SmallVector shiftExprs = { - rewriter.getAffineDimExpr(rank - 1)}; - auto shiftType = - RankedTensorType::get({static_cast(shiftValues.size())}, - rewriter.getIntegerType(8)); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, shiftExprs, - rewriter.getContext())); - shiftArg = indexingMaps.size() - 1; + setupLinalgGenericOpInputAndIndexingMap( + rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op, + shiftConstant, shiftArg, true); + + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext()); + FailureOr maybeIZp = op.getInputZeroPoint(); + FailureOr maybeOZp = op.getOutputZeroPoint(); + // The inputZp and outputZp may be either constant or non-constant, + // depending on whether dynamic extension is enabled. + // - If the zp's are non-constant, add them as an inputs to + // linalg::GenericOp by: + // 1. Pushing it into 'genericInputs'. + // 2. Appending a corresponding affine map to 'indexingMaps'. + // - If the zp's are constant, they would be generated as arith.constant. + int64_t iZpArg = 0; + if (failed(maybeIZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(3), loc)); + indexingMaps.push_back(broadcastMap); + iZpArg = indexingMaps.size() - 1; + } + int64_t oZpArg = 0; + if (failed(maybeOZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(4), loc)); + indexingMaps.push_back(broadcastMap); + oZpArg = indexingMaps.size() - 1; } // Indexing maps for output values. @@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern { Type valueTy = value.getType(); FailureOr maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - - const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); - // Extend zeropoint for sub-32bits widths. - const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp, + nestedLoc, blockArgs, iZpArg); FailureOr maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; + auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp, + nestedLoc, blockArgs, oZpArg, true); IntegerType outIntType = cast(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); - const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index a7a73ae904042..780c25a9445a0 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1538,6 +1538,92 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8> // ----- +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) { + // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor + // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor, tensor, tensor, tensor) outs([[INIT]] : tensor<2xi8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8): + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c-128_i32 = arith.constant -128 : i32 + // CHECK: %c127_i32 = arith.constant 127 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + return %0 : tensor<2xi8> +} + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const_per_channel +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) { + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor, tensor) outs([[INIT]] : tensor<2xi8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8): + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c-128_i32 = arith.constant -128 : i32 + // CHECK: %c127_i32 = arith.constant 127 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + return %0 : tensor<2xi8> +} + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const_per_channel_input_output_zp_ui8 +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const_per_channel_input_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xui8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) { + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xui8> into tensor + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor, tensor) outs([[INIT]] : tensor<2xui8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: ui8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8): + // CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG3]] : ui8 to i8 + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[OUTPUT_ZP_I8]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c0_i32 = arith.constant 0 : i32 + // CHECK: %c255_i32 = arith.constant 255 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xui8>, tensor<1xui8>) -> tensor<2xui8> + return %0 : tensor<2xui8> +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @reverse