diff --git a/.gitmodules b/.gitmodules index 2ab93fe7..8facdfdd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,5 +2,5 @@ [submodule "llvm-project"] path = llvm-project - url = https://github.com/Xilinx/llvm-project.git + url = https://github.com/Xilinx/llvm-aie.git branch = feature/fused-ops diff --git a/lib/Conversion/TosaToXTenNN.cpp b/lib/Conversion/TosaToXTenNN.cpp index 13b47a11..d4c003c9 100644 --- a/lib/Conversion/TosaToXTenNN.cpp +++ b/lib/Conversion/TosaToXTenNN.cpp @@ -101,8 +101,10 @@ m_ConstantFloatLog2(IntegerAttr::ValueType *bindValue) { // NOLINT ///\return true shape does not change from input to output ///\return false shape does change from input to output bool sameInputAndOutputShape(mlir::Operation *operation) { - assert(operation->getNumOperands() == 2 && operation->getNumResults() == 1 && - "expected operation with 2 inputs and one output."); + assert( + (operation->getNumOperands() == 2 || operation->getNumOperands() == 3) && + operation->getNumResults() == 1 && + "expected operation with 2 inputs and one output."); return operation->getOperand(0).getType() == operation->getResult(0).getType(); } @@ -179,7 +181,8 @@ class FoldMulsToQDQOps : public OpRewritePattern { APInt quantizeShift(32, 0, true); auto isQDQPattern = m_Op( m_Op(m_Op( - matchers::m_Any(), m_ConstantFloatLog2(&quantizeShift)))); + matchers::m_Any(), m_ConstantFloatLog2(&quantizeShift), + matchers::m_Any()))); if (!dequantizeOp || !isQDQPattern.match(dequantizeOp)) { return rewriter.notifyMatchFailure(dequantizeMulOp->getLoc(), "expected mul->q->dq->mul pattern."); @@ -261,7 +264,8 @@ class MoveScalarTensorToRHSOfMul : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::MulOp mulOp, PatternRewriter &rewriter) const override { - if (!m_Op(m_Constant(), m_Constant()).match(mulOp)) { + if (!m_Op(m_Constant(), m_Constant(), m_Constant()) + .match(mulOp)) { return rewriter.notifyMatchFailure( mulOp.getLoc(), "only reorganize operands on muls with two constants"); diff --git a/lib/Conversion/XTenNNToTosa.cpp b/lib/Conversion/XTenNNToTosa.cpp index 3695050e..9e39318b 100644 --- a/lib/Conversion/XTenNNToTosa.cpp +++ b/lib/Conversion/XTenNNToTosa.cpp @@ -66,6 +66,15 @@ APFloat convertF32AttrToFloatTy(FloatAttr attr, Type typeToConvertTo) { llvm::RoundingMode::NearestTiesToEven, &losesInfo); return scale; } + +Value getZeroShift(PatternRewriter &rewriter, Location loc) { + auto shiftValueType = RankedTensorType::get({1}, rewriter.getI8Type()); + auto shiftValueAttr = DenseElementsAttr::get(shiftValueType, {int8_t(0)}); + + auto shiftValue = + rewriter.create(loc, shiftValueType, shiftValueAttr); + return shiftValue; +} } // namespace class QuantizeOp : public OpRewritePattern { @@ -96,7 +105,7 @@ class QuantizeOp : public OpRewritePattern { auto mulOp = rewriter.create( quantizeOp.getLoc(), inputType, quantizeOp->getOperand(0), - constOp->getResult(0), rewriter.getI8IntegerAttr(0)); + constOp->getResult(0), getZeroShift(rewriter, quantizeOp.getLoc())); mlir::Value castFrom = mulOp->getResult(0); if (!quantizeOp.getZeroPoint().isZero()) { @@ -189,7 +198,7 @@ class DequantizeOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( dequantizeOp, dequantizeOp->getResult(0).getType(), zeroPointSubOp->getResult(0), constOp->getResult(0), - rewriter.getI8IntegerAttr(0)); + getZeroShift(rewriter, dequantizeOp.getLoc())); return success(); } }; diff --git a/llvm-project b/llvm-project index e3c5c4db..a0fc10d3 160000 --- a/llvm-project +++ b/llvm-project @@ -1 +1 @@ -Subproject commit e3c5c4dbbb4e86186f4b666471738085782e1be7 +Subproject commit a0fc10d350b9b1b29767796f3faeb07132cd08fd diff --git a/test/Conversion/TosaToXTenNN/quantization.mlir b/test/Conversion/TosaToXTenNN/quantization.mlir index b32ebac5..d5e1a486 100644 --- a/test/Conversion/TosaToXTenNN/quantization.mlir +++ b/test/Conversion/TosaToXTenNN/quantization.mlir @@ -22,12 +22,14 @@ module attributes {} { %18 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1xi8>} : () -> tensor<1x1x1x1xi8> %19 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x3x4x4xf32>} : () -> tensor<1x3x4x4xf32> %20 = "tosa.reciprocal"(%19) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> - %21 = "tosa.mul"(%arg0, %20) { shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + %shift_21 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %21 = "tosa.mul"(%arg0, %20, %shift_21) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %22 = "tosa.cast"(%21) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> %23 = "tosa.add"(%22, %18) {} : (tensor<1x3x4x4xi8>, tensor<1x1x1x1xi8>) -> tensor<1x3x4x4xi8> %24 = "tosa.sub"(%23, %18) {} : (tensor<1x3x4x4xi8>, tensor<1x1x1x1xi8>) -> tensor<1x3x4x4xi8> %25 = "tosa.cast"(%24) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %26 = "tosa.mul"(%25, %19) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + %shift_26 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %26 = "tosa.mul"(%25, %19, %shift_26) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %26 : tensor<1x3x4x4xf32> } } @@ -45,12 +47,14 @@ module attributes {} { %17 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %18 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1xi8>} : () -> tensor<1x1x1x1xi8> %19 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %21 = "tosa.mul"(%arg0, %17) { shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_21 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %21 = "tosa.mul"(%arg0, %17, %shift_21) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %22 = "tosa.cast"(%21) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> %23 = "tosa.add"(%22, %18) {} : (tensor<1x3x4x4xi8>, tensor<1x1x1x1xi8>) -> tensor<1x3x4x4xi8> %24 = "tosa.sub"(%23, %18) {} : (tensor<1x3x4x4xi8>, tensor<1x1x1x1xi8>) -> tensor<1x3x4x4xi8> %25 = "tosa.cast"(%24) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %26 = "tosa.mul"(%25, %19) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_26 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %26 = "tosa.mul"(%25, %19, %shift_26) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %26 : tensor<1x3x4x4xf32> } } @@ -68,11 +72,13 @@ module attributes {} { %17 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %18 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1xi8>} : () -> tensor<1x1x1x1xi8> %19 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %21 = "tosa.mul"(%arg0, %17) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_21 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %21 = "tosa.mul"(%arg0, %17, %shift_21) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %22 = "tosa.cast"(%21) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> %24 = "tosa.sub"(%22, %18) {} : (tensor<1x3x4x4xi8>, tensor<1x1x1x1xi8>) -> tensor<1x3x4x4xi8> %25 = "tosa.cast"(%24) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %26 = "tosa.mul"(%25, %19) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_26 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %26 = "tosa.mul"(%25, %19, %shift_26) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %26 : tensor<1x3x4x4xf32> } } @@ -89,10 +95,12 @@ module attributes {} { func.func @all_ops_folded(%arg0: tensor<1x3x4x4xf32> ) -> tensor<1x3x4x4xf32> { %17 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %19 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %21 = "tosa.mul"(%arg0, %17) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_21 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %21 = "tosa.mul"(%arg0, %17, %shift_21) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %22 = "tosa.cast"(%21) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> %25 = "tosa.cast"(%22) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %26 = "tosa.mul"(%25, %19) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_26 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %26 = "tosa.mul"(%25, %19, %shift_26) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %26 : tensor<1x3x4x4xf32> } } @@ -104,14 +112,16 @@ module attributes {} { // CHECK-LABEL: func.func @missing_dq_cast_mul( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> { // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> -// CHECK: return %[[VAL_3]] : tensor<1x3x4x4xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> +// CHECK: return %[[VAL_4]] : tensor<1x3x4x4xi8> // CHECK: } func.func @missing_dq_cast_mul(%arg0: tensor<1x3x4x4xf32> ) -> tensor<1x3x4x4xi8> { %17 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %19 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %21 = "tosa.mul"(%arg0, %17) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift_21 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %21 = "tosa.mul"(%arg0, %17, %shift_21) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %22 = "tosa.cast"(%21) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> return %22 : tensor<1x3x4x4xi8> } @@ -137,82 +147,93 @@ module attributes {} { } // -- - module attributes {} { -// CHECK-LABEL: func.func @mul_missing_dequantize( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_2]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @mul_missing_dequantize +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_2_]] : tensor<1x3x4x4xf32> // CHECK: } +// CHECK: } func.func @mul_missing_dequantize(%arg0: tensor<1x3x4x4xf32> ) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %1 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %1 : tensor<1x3x4x4xf32> + %1 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.mul"(%arg0, %0, %1) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %2 : tensor<1x3x4x4xf32> } } // -- module attributes {} { -// CHECK-LABEL: func.func @mul_missing_quantize( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_2:.*]] = xten_nn.dequantize(%[[VAL_0]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_2]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_3]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @mul_missing_quantize +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = xten_nn.dequantize([[PARAM_0_]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_1_]], [[VAR_0_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_3_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @mul_missing_quantize(%arg0: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = xten_nn.dequantize(%arg0 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> - %2 = "tosa.mul"(%1, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %2 : tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%1, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %3 : tensor<1x3x4x4xf32> } } // -- module attributes {} { -// CHECK-LABEL: func.func @unequal_mul_constants( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.000000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_4:.*]] = xten_nn.quantize(%[[VAL_3]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> -// CHECK: %[[VAL_5:.*]] = xten_nn.dequantize(%[[VAL_4]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_6]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @unequal_mul_constants +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.000000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]] : tensor<1x3x4x4xf32> +// CHECK: } // CHECK: } func.func @unequal_mul_constants(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.00000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> - %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %5 : tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = xten_nn.quantize(%3 : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> + %5 = xten_nn.dequantize(%4 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> + %6 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %7 = "tosa.mul"(%5, %1, %6) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %7 : tensor<1x3x4x4xf32> } } // -- module attributes {} { -// CHECK-LABEL: func.func @equal_mul_constants_not_log2( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_3:.*]] = xten_nn.quantize(%[[VAL_2]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> -// CHECK: %[[VAL_4:.*]] = xten_nn.dequantize(%[[VAL_3]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_5]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @equal_mul_constants_not_log2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_3_:%.+]] = xten_nn.quantize([[VAR_2_]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_4_:%.+]] = xten_nn.dequantize([[VAR_3_]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_5_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @equal_mul_constants_not_log2(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.00000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %1 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.mul"(%arg0, %0, %1) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %5 : tensor<1x3x4x4xf32> + %5 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %6 = "tosa.mul"(%4, %0, %5) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %6 : tensor<1x3x4x4xf32> } } @@ -228,91 +249,96 @@ module attributes {} { func.func @sum_shifts(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> - %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %5 : tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = xten_nn.quantize(%3 : tensor<1x3x4x4xf32>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> + %5 = xten_nn.dequantize(%4 : tensor<1x3x4x4xi8>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> + %6 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %7 = "tosa.mul"(%5, %1, %6) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %7 : tensor<1x3x4x4xf32> } } // -- - module attributes {} { // CHECK-LABEL: func.func @sum_shifts_no_shifts // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_3_:%.+]] = xten_nn.quantize([[VAR_2_]] : tensor<1x3x4x4xf32>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xi8> -// CHECK: [[VAR_4_:%.+]] = xten_nn.dequantize([[VAR_3_]] : tensor<1x3x4x4xi8>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return [[VAR_5_]] : tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<1x3x4x4xf32>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<1x3x4x4xi8>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @sum_shifts_no_shifts(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> - %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - return %5 : tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = xten_nn.quantize(%3 : tensor<1x3x4x4xf32>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> + %5 = xten_nn.dequantize(%4 : tensor<1x3x4x4xi8>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + return %6 : tensor<1x3x4x4xf32> } } // -- module attributes {} { -// CHECK-LABEL: func.func @multiple_q_uses( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x4xi8>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8>) { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> -// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_3]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_1]], %[[VAL_5]] : (tensor<1x3x4x4xi8>, tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xi8> -// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8> +// CHECK-LABEL: func.func @multiple_q_uses +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x3x4x4xi8>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8>) { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.add [[PARAM_1_]], [[VAR_4_]] : (tensor<1x3x4x4xi8>, tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xi8> +// CHECK: return [[VAR_6_]], [[VAR_7_]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8> // CHECK: } func.func @multiple_q_uses(%arg0: tensor<1x3x4x4xf32>, %arg1: tensor<1x3x4x4xi8>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8>) { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = "tosa.cast"(%2) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> - %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> + %5 = "tosa.cast"(%4) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // quantized output is used twice, so we cannot replace the casts here we want strictly Q->DQ - %6 = "tosa.add"(%arg1, %3) : (tensor<1x3x4x4xi8>, tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xi8> - return %5, %6: tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8> + %7 = "tosa.add"(%arg1, %4) : (tensor<1x3x4x4xi8>, tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xi8> + return %6, %7: tensor<1x3x4x4xf32>, tensor<1x3x4x4xi8> } } // -- module attributes {} { -// CHECK-LABEL: func.func @multiple_dq_uses( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_5:.*]] = xten_nn.quantize(%[[VAL_4]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> -// CHECK: %[[VAL_6:.*]] = xten_nn.dequantize(%[[VAL_5]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_3]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_1]], %[[VAL_6]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @multiple_dq_uses +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.add [[PARAM_1_]], [[VAR_5_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]], [[VAR_7_]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> // CHECK: } func.func @multiple_dq_uses(%arg0: tensor<1x3x4x4xf32>, %arg1: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = "tosa.cast"(%2) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> - %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> + %5 = "tosa.cast"(%4) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // dequantize output is used twice, meaning the muls will not fold. - %6 = "tosa.add"(%arg1, %4) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> - return %5, %6: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> + %7 = "tosa.add"(%arg1, %5) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + return %6, %7: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> } } @@ -330,44 +356,45 @@ module attributes {} { func.func @multiple_dq_uses_muls_fold(%arg0: tensor<1x3x4x4xf32>, %arg1: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { %0 = "tosa.const"() {value = dense<1.0> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<1.0> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = "tosa.cast"(%2) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> - %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> + %5 = "tosa.cast"(%4) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // dequantize output is used twice, but the scale factor is one so the muls fold away - %6 = "tosa.add"(%arg1, %4) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> - return %5, %6: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> + %7 = "tosa.add"(%arg1, %5) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + return %6, %7: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> } } // -- module attributes {} { -// CHECK-LABEL: func.func @multiple_mul_q_uses( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_5:.*]] = xten_nn.quantize(%[[VAL_4]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> -// CHECK: %[[VAL_6:.*]] = xten_nn.dequantize(%[[VAL_5]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_3]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_1]], %[[VAL_4]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @multiple_mul_q_uses +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.add [[PARAM_1_]], [[VAR_3_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]], [[VAR_7_]] : tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> // CHECK: } func.func @multiple_mul_q_uses(%arg0: tensor<1x3x4x4xf32>, %arg1: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = "tosa.cast"(%2) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> - %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %2 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%arg0, %0, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> + %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> + %5 = "tosa.cast"(%4) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %2) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // quantize mul output is used twice this is not allowed, it should only be used by the QuantizeOp - %6 = "tosa.add"(%arg1, %2) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> - return %5, %6: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> + %7 = "tosa.add"(%arg1, %3) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + return %6, %7: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> } } - // -- module attributes {} { @@ -382,10 +409,11 @@ module attributes {} { func.func @multiple_mul_dq_uses(%arg0: tensor<1x3x4x4xf32>, %arg1: tensor<1x3x4x4xf32>) -> (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.mul"(%arg0, %0, %shift) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %3 = "tosa.cast"(%2) : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> %4 = "tosa.cast"(%3) : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %5 = "tosa.mul"(%4, %1, %shift) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // dequantize mul output is used twice, but that can be expected. %6 = "tosa.add"(%arg1, %5) : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> return %5, %6: tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32> @@ -395,24 +423,26 @@ module attributes {} { // -- module attributes {} { -// CHECK-LABEL: func.func @broadcast_mul_on_quantize( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3xf32>) -> tensor<4x3xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<4x3xf32>}> : () -> tensor<4x3xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> -// CHECK: %[[VAL_4:.*]] = xten_nn.quantize(%[[VAL_3]] : tensor<4x3xf32>) {shift = 0 : si32} -> tensor<4x3xi8> -// CHECK: %[[VAL_5:.*]] = xten_nn.dequantize(%[[VAL_4]] : tensor<4x3xi8>) {shift = 0 : si32} -> tensor<4x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]] {shift = 0 : i8} : (tensor<4x3xf32>, tensor<1x1xf32>) -> tensor<4x3xf32> -// CHECK: return %[[VAL_6]] : tensor<4x3xf32> +// CHECK-LABEL: func.func @broadcast_mul_on_quantize +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3xf32>) -> tensor<4x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<4x3xf32>}> : () -> tensor<4x3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3xf32>, tensor<4x3xf32>, tensor<1xi8>) -> tensor<4x3xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<4x3xf32>) {shift = 0 : si32} -> tensor<4x3xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<4x3xi8>) {shift = 0 : si32} -> tensor<4x3xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<4x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<4x3xf32> +// CHECK: return [[VAR_6_]] : tensor<4x3xf32> // CHECK: } func.func @broadcast_mul_on_quantize(%arg0: tensor<1x3xf32>) -> tensor<4x3xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<4x3xf32>} : () -> tensor<4x3xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // Mul cannot be folded because the output shape changes w.r.t input due to broadcasting. - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %2 = "tosa.mul"(%arg0, %0, %shift) : (tensor<1x3xf32>, tensor<4x3xf32>, tensor<1xi8>) -> tensor<4x3xf32> %3 = "tosa.cast"(%2) : (tensor<4x3xf32>) -> tensor<4x3xi8> %4 = "tosa.cast"(%3) : (tensor<4x3xi8>) -> tensor<4x3xf32> - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<4x3xf32>, tensor<1x1xf32>) -> tensor<4x3xf32> + %5 = "tosa.mul"(%4, %1, %shift) : (tensor<4x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<4x3xf32> return %5 : tensor<4x3xf32> } } @@ -420,24 +450,26 @@ module attributes {} { // -- module attributes {} { -// CHECK-LABEL: func.func @broadcast_mul_on_dequantize( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3xf32>) -> tensor<4x3xf32> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<4x3xf32>}> : () -> tensor<4x3xf32> -// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3xf32>, tensor<1x1xf32>) -> tensor<1x3xf32> -// CHECK: %[[VAL_4:.*]] = xten_nn.quantize(%[[VAL_3]] : tensor<1x3xf32>) {shift = 0 : si32} -> tensor<1x3xi8> -// CHECK: %[[VAL_5:.*]] = xten_nn.dequantize(%[[VAL_4]] : tensor<1x3xi8>) {shift = 0 : si32} -> tensor<1x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> -// CHECK: return %[[VAL_6]] : tensor<4x3xf32> +// CHECK-LABEL: func.func @broadcast_mul_on_dequantize +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3xf32>) -> tensor<4x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<4x3xf32>}> : () -> tensor<4x3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<1x3xf32> +// CHECK: [[VAR_4_:%.+]] = xten_nn.quantize([[VAR_3_]] : tensor<1x3xf32>) {shift = 0 : si32} -> tensor<1x3xi8> +// CHECK: [[VAR_5_:%.+]] = xten_nn.dequantize([[VAR_4_]] : tensor<1x3xi8>) {shift = 0 : si32} -> tensor<1x3xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3xf32>, tensor<4x3xf32>, tensor<1xi8>) -> tensor<4x3xf32> +// CHECK: return [[VAR_6_]] : tensor<4x3xf32> // CHECK: } func.func @broadcast_mul_on_dequantize(%arg0: tensor<1x3xf32>) -> tensor<4x3xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1xf32>} : () -> tensor<1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<4x3xf32>} : () -> tensor<4x3xf32> - %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3xf32>, tensor<1x1xf32>) -> tensor<1x3xf32> + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.mul"(%arg0, %0, %shift) : (tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<1x3xf32> %3 = "tosa.cast"(%2) : (tensor<1x3xf32>) -> tensor<1x3xi8> %4 = "tosa.cast"(%3) : (tensor<1x3xi8>) -> tensor<1x3xf32> // Mul cannot be folded because the output shape changes w.r.t input due to broadcasting. - %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + %5 = "tosa.mul"(%4, %1, %shift) : (tensor<1x3xf32>, tensor<4x3xf32>, tensor<1xi8>) -> tensor<4x3xf32> return %5 : tensor<4x3xf32> } } @@ -454,12 +486,13 @@ module attributes {} { func.func @sort_mul_operands(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // The MULs have the constants on operand(0) the SortCommutativeOperands pass should move them to // operand(1) and the MUL folding should occur. - %2 = "tosa.mul"(%0, %arg0) {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + %2 = "tosa.mul"(%0, %arg0, %shift) : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> - %5 = "tosa.mul"(%1, %4) {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + %5 = "tosa.mul"(%1, %4, %shift) : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %5 : tensor<1x3x4x4xf32> } } @@ -477,10 +510,11 @@ module attributes {} { %0 = "tosa.const"() {value = dense<1.280000e+02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<7.812500e-03> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x3x4x4xf32>} : () -> tensor<1x3x4x4xf32> - %3 = "tosa.mul"(%0, %2) {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %3 = "tosa.mul"(%0, %2, %shift) : (tensor<1x1x1x1xf32>, tensor<1x3x4x4xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> %4 = xten_nn.quantize(%3 : tensor<1x3x4x4xf32>) {shift = 0 : si32} -> tensor<1x3x4x4xi8> %5 = xten_nn.dequantize(%4 : tensor<1x3x4x4xi8>) {shift = 0 : si32} -> tensor<1x3x4x4xf32> - %6 = "tosa.mul"(%5, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %6 = "tosa.mul"(%5, %1, %shift) : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> return %6 : tensor<1x3x4x4xf32> } } diff --git a/test/Conversion/XTenNNToTosa/quantization.mlir b/test/Conversion/XTenNNToTosa/quantization.mlir index d73c23c4..4272676a 100644 --- a/test/Conversion/XTenNNToTosa/quantization.mlir +++ b/test/Conversion/XTenNNToTosa/quantization.mlir @@ -19,13 +19,14 @@ module attributes{} { // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xf32>}> : () -> tensor<1x3x4x4xf32> // CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1> : tensor<1x3x4x4xi32>}> : () -> tensor<1x3x4x4xi32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_4_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_6_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi8> +// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.add [[VAR_6_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_7_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi8> // CHECK: [[VAR_9_:%.+]] = tosa.cast [[VAR_8_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> // CHECK: [[VAR_10_:%.+]] = tosa.sub [[VAR_9_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]], [[VAR_4_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> // CHECK: return [[VAR_11_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @explicit_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { @@ -44,13 +45,14 @@ module attributes{} { // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xbf16>}> : () -> tensor<1x3x4x4xbf16> // CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1> : tensor<1x3x4x4xi32>}> : () -> tensor<1x3x4x4xi32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> -// CHECK: [[VAR_4_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> -// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_6_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi8> +// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>, tensor<1xi8>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.add [[VAR_6_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_7_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi8> // CHECK: [[VAR_9_:%.+]] = tosa.cast [[VAR_8_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xbf16> // CHECK: [[VAR_10_:%.+]] = tosa.sub [[VAR_9_]], [[VAR_1_]] : (tensor<1x3x4x4xbf16>, tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> -// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]], [[VAR_4_]] : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>, tensor<1xi8>) -> tensor<1x3x4x4xbf16> // CHECK: return [[VAR_11_]] : tensor<1x3x4x4xbf16> // CHECK: } func.func @explicit_case_bf16(%arg0: tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> { @@ -70,13 +72,14 @@ module attributes{} { // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xbf16>}> : () -> tensor<1x3x4x4xbf16> // CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1> : tensor<1x3x4x4xi32>}> : () -> tensor<1x3x4x4xi32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> -// CHECK: [[VAR_4_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> -// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_6_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xui16> +// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>, tensor<1xi8>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.add [[VAR_6_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_7_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xui16> // CHECK: [[VAR_9_:%.+]] = tosa.cast [[VAR_8_]] : (tensor<1x3x4x4xui16>) -> tensor<1x3x4x4xbf16> // CHECK: [[VAR_10_:%.+]] = tosa.sub [[VAR_9_]], [[VAR_1_]] : (tensor<1x3x4x4xbf16>, tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> -// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]], [[VAR_4_]] : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>, tensor<1xi8>) -> tensor<1x3x4x4xbf16> // CHECK: return [[VAR_11_]] : tensor<1x3x4x4xbf16> // CHECK: } func.func @explicit_case_bf16_to_uint16(%arg0: tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> { @@ -93,12 +96,13 @@ module attributes{} { // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<8.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.250000e-01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> -// CHECK: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<2x3xf32>) -> tensor<2x3xi4> -// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<2x3xi4>) -> tensor<2x3xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> -// CHECK: return [[VAR_5_]] : tensor<2x3xf32> -// CHECK: } +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<2x3xf32>) -> tensor<2x3xi4> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<2x3xi4>) -> tensor<2x3xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_0_]], [[VAR_2_]] : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32> +// CHECK: return [[VAR_6_]] : tensor<2x3xf32> +// CHECK: } func.func @small_tensors(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { %0 = xten_nn.quantize(%arg0 : tensor<2x3xf32>) {scale = 8.0 : f32, shift = 3 : si32, zero_point = 0 : i4} -> tensor<2x3xi4> %1 = xten_nn.dequantize(%0 : tensor<2x3xi4>) {scale = 8.0 : f32, shift = 3 : si32, zero_point = 0 : i4} -> tensor<2x3xf32> @@ -111,10 +115,11 @@ module attributes{} { module attributes{} { // CHECK-LABEL: func.func @quantize_case // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> { -// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_2_:%.+]] = tosa.cast [[VAR_1_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> -// CHECK: return [[VAR_2_]] : tensor<1x3x4x4xi8> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> +// CHECK: return [[VAR_3_]] : tensor<1x3x4x4xi8> // CHECK: } func.func @quantize_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> { %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i8} -> tensor<1x3x4x4xi8> @@ -125,12 +130,13 @@ module attributes{} { // -- module attributes{} { -// CHECK-LABEL: func.func @dequantize_case( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.cast %[[VAL_0]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK-DAG: return %[[VAL_4]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @dequantize_case +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_1_]], [[VAR_0_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_3_]] : tensor<1x3x4x4xf32> func.func @dequantize_case(%arg0: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { %0 = xten_nn.dequantize(%arg0 : tensor<1x3x4x4xi8>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i8} -> tensor<1x3x4x4xf32> return %0 : tensor<1x3x4x4xf32> @@ -144,12 +150,12 @@ module attributes{} { // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi16> -// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<1x3x4x4xi16>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return [[VAR_5_]] : tensor<1x3x4x4xf32> -// CHECK: } +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi16> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xi16>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]] : tensor<1x3x4x4xf32> func.func @i16_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i16} -> tensor<1x3x4x4xi16> %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi16>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i16} -> tensor<1x3x4x4xf32> @@ -168,14 +174,15 @@ module attributes{} { // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xf32>}> : () -> tensor<1x3x4x4xf32> // CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1> : tensor<1x3x4x4xi32>}> : () -> tensor<1x3x4x4xi32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_4_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> -// CHECK: [[VAR_7_:%.+]] = tosa.cast [[VAR_6_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi16> -// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_7_]] : (tensor<1x3x4x4xi16>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_9_:%.+]] = tosa.sub [[VAR_8_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_10_:%.+]] = tosa.mul [[VAR_9_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return [[VAR_10_]] : tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_3_]], [[VAR_4_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.add [[VAR_6_]], [[VAR_2_]] : (tensor<1x3x4x4xi32>, tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi32> +// CHECK: [[VAR_8_:%.+]] = tosa.cast [[VAR_7_]] : (tensor<1x3x4x4xi32>) -> tensor<1x3x4x4xi16> +// CHECK: [[VAR_9_:%.+]] = tosa.cast [[VAR_8_]] : (tensor<1x3x4x4xi16>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_10_:%.+]] = tosa.sub [[VAR_9_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_11_:%.+]] = tosa.mul [[VAR_10_]], [[VAR_0_]], [[VAR_4_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_11_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @i16_case_zero_point(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, zero_point = 1 : i16} -> tensor<1x3x4x4xi16> @@ -191,11 +198,12 @@ module attributes{} { // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_3_:%.+]] = tosa.cast [[VAR_2_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi12> -// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<1x3x4x4xi12>) -> tensor<1x3x4x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK: return [[VAR_5_]] : tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.cast [[VAR_3_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi12> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xi12>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_5_]], [[VAR_0_]], [[VAR_2_]] : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_6_]] : tensor<1x3x4x4xf32> // CHECK: } func.func @i12_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i12} -> tensor<1x3x4x4xi12>