diff --git a/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/conversions/tosa/tests/binary.mlir index bf440f198f..60342539ef 100644 --- a/stablehlo/conversions/tosa/tests/binary.mlir +++ b/stablehlo/conversions/tosa/tests/binary.mlir @@ -52,10 +52,10 @@ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi // CHECK-LABEL: @dot_vector_vector func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -65,10 +65,10 @@ func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> te // CHECK-LABEL: @dot_vector_matrix func.func @dot_vector_matrix(%arg0 : tensor<2xf32>, %arg1 : tensor<2x3xf32>) -> tensor<3xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -78,10 +78,10 @@ func.func @dot_vector_matrix(%arg0 : tensor<2xf32>, %arg1 : tensor<2x3xf32>) -> // CHECK-LABEL: @dot_matrix_vector func.func @dot_matrix_vector(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3xf32>) -> tensor<2xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -91,10 +91,10 @@ func.func @dot_matrix_vector(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3xf32>) -> // CHECK-LABEL: @dot_matrix_matrix func.func @dot_matrix_matrix(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -104,10 +104,10 @@ func.func @dot_matrix_matrix(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>) - // CHECK-LABEL: @dot_general_vector_vector func.func @dot_general_vector_vector(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -117,10 +117,10 @@ func.func @dot_general_vector_vector(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) // CHECK-LABEL: @dot_general_vector_matrix func.func @dot_general_vector_matrix(%arg0: tensor<2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -130,10 +130,10 @@ func.func @dot_general_vector_matrix(%arg0: tensor<2xf32>, %arg1: tensor<2x3xf32 // CHECK-LABEL: @dot_general_matrix_vector func.func @dot_general_matrix_vector(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] @@ -143,10 +143,10 @@ func.func @dot_general_matrix_vector(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32 // CHECK-LABEL: @dot_general_matrix_matrix func.func @dot_general_matrix_matrix(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]] - // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]] // CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]] diff --git a/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir b/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir index e700ebde9a..5c198980e7 100644 --- a/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir +++ b/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir @@ -4,10 +4,17 @@ // CHECK-LABEL: @add func.func @add(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[SHIFT50:.+]] = "tosa.const"() <{values = dense<50> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT11:.+]] = "tosa.const"() <{values = dense<11> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT13:.+]] = "tosa.const"() <{values = dense<13> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT13]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT11]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.add %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_1]], %[[SHIFT50]], %[[ZP_0]], %[[ZP_MINUS_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -18,10 +25,17 @@ func.func @add(%arg0 : tensor<2x2x!quant.uniform>, // CHECK-LABEL: @sub func.func @sub(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[SHIFT50:.+]] = "tosa.const"() <{values = dense<50> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT11:.+]] = "tosa.const"() <{values = dense<11> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT13:.+]] = "tosa.const"() <{values = dense<13> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT13]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT11]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.subtract %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_1]], %[[SHIFT50]], %[[ZP_0]], %[[ZP_MINUS_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -32,10 +46,16 @@ func.func @sub(%arg0 : tensor<2x2x!quant.uniform>, // CHECK-LABEL: @mul func.func @mul(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[SHIFT37:.+]] = "tosa.const"() <{values = dense<37> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT30:.+]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1717986918> : tensor<1xi32>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_1]], %[[SHIFT30]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT30]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.multiply %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_2]], %[[SHIFT37]], %[[ZP_0]], %[[ZP_MINUS_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -46,10 +66,18 @@ func.func @mul(%arg0 : tensor<2x2x!quant.uniform>, // CHECK-LABEL: @div func.func @div(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -2 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[SHIFT37:.+]] = "tosa.const"() <{values = dense<37> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT30:.+]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1717986918> : tensor<1xi32>}> + // CHECK-DAG: %[[ZP_MINUS_3:.+]] = "tosa.const"() <{values = dense<-3> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_2:.+]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_1]], %[[SHIFT30]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT30]], %[[ZP_MINUS_2]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.divide %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_2]], %[[SHIFT37]], %[[ZP_0]], %[[ZP_MINUS_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -60,10 +88,19 @@ func.func @div(%arg0 : tensor<2x2x!quant.uniform>, // CHECK-LABEL: @max func.func @max(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -2 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[ZP_MINUS_3:.+]] = "tosa.const"() <{values = dense<-3> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT51:.+]] = "tosa.const"() <{values = dense<51> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_2:.+]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT10:.+]] = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[SHIFT12:.+]] = "tosa.const"() <{values = dense<12> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT12]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT10]], %[[ZP_MINUS_2]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.maximum %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_1]], %[[SHIFT51]], %[[ZP_0]], %[[ZP_MINUS_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -74,10 +111,19 @@ func.func @max(%arg0 : tensor<2x2x!quant.uniform>, // CHECK-LABEL: @min func.func @min(%arg0 : tensor<2x2x!quant.uniform>, %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -2 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[ZP_MINUS_3:.+]] = "tosa.const"() <{values = dense<-3> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT51:.+]] = "tosa.const"() <{values = dense<51> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_2:.+]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT10:.+]] = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[SHIFT12:.+]] = "tosa.const"() <{values = dense<12> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT12]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT10]], %[[ZP_MINUS_2]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.minimum %[[V0]], %[[V1]] : tensor<2x2xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]], %[[MULTIPLIER_1]], %[[SHIFT51]], %[[ZP_0]], %[[ZP_MINUS_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> @@ -87,9 +133,16 @@ func.func @min(%arg0 : tensor<2x2x!quant.uniform>, // ----- // CHECK-LABEL: @abs func.func @abs(%arg0 : tensor<20x20x!quant.uniform>) -> tensor<20x20x!quant.uniform> { - // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[ZP_MINUS_128:.+]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT33:.+]] = "tosa.const"() <{values = dense<33> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[SHIFT30:.+]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT30]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V1:.+]] = stablehlo.abs %[[V0]] : tensor<20x20xi32> - // CHECK: %[[V3:.+]] = tosa.rescale %[[V1]] {double_round = false, input_unsigned = false, input_zp = 0 : i32, multiplier = array, output_unsigned = false, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V3:.+]] = tosa.rescale %[[V1]], %[[MULTIPLIER_1]], %[[SHIFT33]], %[[ZP_0]], %[[ZP_MINUS_128]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: return %[[V3]] : tensor<20x20x!quant.uniform> %0 = "stablehlo.abs"(%arg0) : (tensor<20x20x!quant.uniform>) -> tensor<20x20x!quant.uniform> return %0 : tensor<20x20x!quant.uniform> @@ -99,8 +152,15 @@ func.func @abs(%arg0 : tensor<20x20x!quant.uniform>) -> tensor // CHECK-LABEL: @compareGE func.func @compareGE(%arg0 : tensor<20x20x!quant.uniform>, %arg1 : tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> { - // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -2 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[ZP_MINUS_2:.+]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> + // CHECK-DAG: %[[SHIFT10:.+]] = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[SHIFT12:.+]] = "tosa.const"() <{values = dense<12> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_MINUS_1:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT12]], %[[ZP_MINUS_1]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT10]], %[[ZP_MINUS_2]], %[[ZP_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.compare GE, %[[V0]], %[[V1]], TOTALORDER : // CHECK: return %[[V2]] %0 = stablehlo.compare GE, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform>, tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> @@ -109,12 +169,18 @@ func.func @compareGE(%arg0 : tensor<20x20x!quant.uniform>, // ----- // CHECK-LABEL: @compareLT -func.func @compareLT(%arg0 : tensor<20x20x!quant.uniform>, - %arg1 : tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> { - // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_unsigned = false, input_zp = -1 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_unsigned = false, input_zp = -2 : i32, multiplier = array, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +func.func @compareLT(%arg0 : tensor<20x20x!quant.uniform>, + %arg1 : tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> { + // CHECK-DAG: %[[SHIFT17:.+]] = "tosa.const"() <{values = dense<17> : tensor<1xi8>}> + // CHECK-DAG: %[[MULTIPLIER_1:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> + // CHECK-DAG: %[[MULTIPLIER_2:.+]] = "tosa.const"() <{values = dense<1431655765> : tensor<1xi32>}> + // CHECK-DAG: %[[SHIFT15:.+]] = "tosa.const"() <{values = dense<15> : tensor<1xi8>}> + // CHECK-DAG: %[[ZP16_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> + // CHECK-DAG: %[[ZP32_0:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK: %[[V0:.+]] = tosa.rescale %arg0, %[[MULTIPLIER_2]], %[[SHIFT17]], %[[ZP16_0]], %[[ZP32_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} + // CHECK: %[[V1:.+]] = tosa.rescale %arg1, %[[MULTIPLIER_1]], %[[SHIFT15]], %[[ZP16_0]], %[[ZP32_0]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: %[[V2:.+]] = stablehlo.compare LT, %[[V0]], %[[V1]], TOTALORDER : // CHECK: return %[[V2]] - %0 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform>, tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> + %0 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform>, tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> return %0 : tensor<20x20xi1> } diff --git a/stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir b/stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir index 8a1d6cc513..bdab792db9 100644 --- a/stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir +++ b/stablehlo/conversions/tosa/tests/legalize_tosa_rescale_to_stablehlo.mlir @@ -3,8 +3,12 @@ // ----- // CHECK-LABEL: @rescale func.func @rescale(%arg0 : tensor<2x2x!quant.uniform>) -> tensor<2x2xi32> { - %0 = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, input_unsigned = false, output_unsigned = false, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : - (tensor<2x2x!quant.uniform>) -> tensor<2x2xi32> + %multiplier = "tosa.const"() {values = dense<1431655765> : tensor<1xi32>} : () -> tensor<1xi32> + %shift = "tosa.const"() {values = dense<13> : tensor<1xi8>} : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : + (tensor<2x2x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<2x2xi32> // convert input quantized type to storage type // CHECK-DAG: %[[arg:.+]] = stablehlo.bitcast_convert %arg0 : (tensor<2x2x!quant.uniform>) -> tensor<2x2xi8> diff --git a/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/conversions/tosa/tests/nullary.mlir index d4f9c6d4bf..051084e6a7 100644 --- a/stablehlo/conversions/tosa/tests/nullary.mlir +++ b/stablehlo/conversions/tosa/tests/nullary.mlir @@ -17,8 +17,8 @@ func.func @constant_f64() -> tensor<10xf64> { // CHECK-LABEL: @iota_dimension_0 func.func @iota_dimension_0() -> tensor<4x8xf32> { // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> : () -> tensor<4x1xf32> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 8]> : vector<2xindex>} : () -> !tosa.shape<2> + // CHECK-SAME{LITERAL}: <{values = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> : () -> tensor<4x1xf32> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 8]> : vector<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]] %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>) return %0 : tensor<4x8xf32> @@ -27,8 +27,8 @@ func.func @iota_dimension_0() -> tensor<4x8xf32> { // CHECK-LABEL: @iota_dimension_1 func.func @iota_dimension_1() -> tensor<4x8xi32> { // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> : () -> tensor<1x8xi32> - // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[4, 1]> : vector<2xindex>} : () -> !tosa.shape<2> + // CHECK-SAME{LITERAL}: <{values = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> : () -> tensor<1x8xi32> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[4, 1]> : vector<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]] %0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>) return %0 : tensor<4x8xi32> diff --git a/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/conversions/tosa/tests/unary.mlir index d0b8afb50d..60b64cde23 100644 --- a/stablehlo/conversions/tosa/tests/unary.mlir +++ b/stablehlo/conversions/tosa/tests/unary.mlir @@ -30,7 +30,7 @@ func.func @exponential(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @exponential_minus_one func.func @exponential_minus_one(%arg : tensor<10xf32>) -> tensor<10xf32> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> // CHECK-DAG: %[[VAR1:.*]] = tosa.exp %arg0 // CHECK-DAG: %[[VAR2:.*]] = tosa.sub %[[VAR1]], %[[VAR0]] %0 = "stablehlo.exponential_minus_one"(%arg) : (tensor<10xf32>) -> tensor<10xf32> @@ -46,7 +46,7 @@ func.func @floor(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @is_finite func.func @is_finite(%arg : tensor<10xf32>) -> tensor<10xi1> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0x7F800000> + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0x7F800000> // CHECK-DAG: %[[VAR1:.*]] = tosa.abs %arg0 // CHECK-DAG: %[[VAR2:.*]] = tosa.equal %[[VAR1]], %[[VAR0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.logical_not %[[VAR2]] @@ -63,7 +63,7 @@ func.func @log(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @log_plus_one func.func @log_plus_one(%arg : tensor<10xf16>) -> tensor<10xf16> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> // CHECK-DAG: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] // CHECK-DAG: %[[VAR2:.*]] = tosa.log %[[VAR1]] %0 = "stablehlo.log_plus_one"(%arg) : (tensor<10xf16>) -> tensor<10xf16> @@ -79,8 +79,8 @@ func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @slice func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> { - // CHECK: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> - // CHECK: %[[START:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %[[START:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]] %0 = "stablehlo.slice"(%arg) { start_indices = array, @@ -130,8 +130,8 @@ func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { // CHECK-LABEL: @while func.func @while(%arg0: tensor) -> tensor { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<3> : tensor} - // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<1> : tensor} + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<3> : tensor} + // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<1> : tensor} // CHECK: %[[VAR2:.*]] = tosa.while_loop (%[[ARG1:.+]] = %arg0) : (tensor) -> tensor { // CHECK: %[[VAR3:.*]] = tosa.equal %[[ARG1]], %[[VAR0]] // CHECK: tosa.yield %[[VAR3]] diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index cbd299781a..ef849ab9a0 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -25,6 +25,11 @@ Rewrite getScalarInt8Tensor() -> Type [{ return RankedTensorType::get({1}, rewriter.getI8Type()); }]; +Rewrite getScalarTensor(type: Type) -> Type [{ + auto elementType = llvm::cast(type).getElementType(); + return RankedTensorType::get({1}, elementType); +}]; + Rewrite zerosLike(op: Op, type: Type) -> Op [{ auto elementType = llvm::cast(type).getElementType(); llvm::SmallVector outputValue; @@ -75,7 +80,7 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ // Nullary ops. Pattern => replace op {value = input: Attr<_: Tosa_Tensor>} - with op {value = input}; + with op {values = input}; // Unary ops. Pattern => @@ -126,9 +131,15 @@ Pattern { replace root with logPlusOneResult; }; } -Pattern => - replace op(input : Value<_: Tosa_Tensor>) - with op(input); +Pattern { + let root = op(input : Value); + rewrite root with { + let scalarType = getScalarTensor(inputType); + let zp = zerosLike(root, scalarType); + let negResult = op(input, zp, zp) -> (inputType); + replace root with negResult; + }; +} Pattern => replace op(input : Value<_: Tosa_Tensor>) with op(input); diff --git a/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp b/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp index b7cef8e2b1..ba00f8ff8b 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp @@ -40,19 +40,42 @@ namespace tosa { namespace { +Value buildRescaleMultiplier(bool scale32, OpBuilder &builder, Location loc, + ArrayRef multipliers) { + if (scale32) { + return tosa::getConstTensorInt(builder, loc, multipliers); + } else { + SmallVector vec(multipliers.begin(), multipliers.end()); + return tosa::getConstTensorInt(builder, loc, vec); + } +} + // create a tosa rescale op and return its result value Value buildRescale(PatternRewriter &rewriter, Location loc, ShapedType outputType, Value inputVal, int32_t multiplier, int32_t shift, int64_t inputZp, int64_t outputZp, bool doubleRound, bool scale32, bool perChannel) { + auto multiplierVal = + buildRescaleMultiplier(scale32, rewriter, loc, {multiplier}); + auto shiftVal = tosa::getConstTensorInt(rewriter, loc, + {static_cast(shift)}); + auto inputZpVal = + tosa::createZeroPointTensor(rewriter, loc, inputVal.getType(), inputZp); + assert( + inputZpVal.has_value() && + "buildRescale: Failed to create input zero-point tensor for RescaleOp."); + auto outputZpVal = + tosa::createZeroPointTensor(rewriter, loc, outputType, outputZp); + assert( + outputZpVal.has_value() && + "buildRescale: Failed to create output zero-point tensor for RescaleOp."); + + std::string roundingMode = doubleRound ? "DOUBLE_ROUND" : "SINGLE_ROUND"; + auto rescale_op = rewriter.create( - loc, outputType, inputVal, - rewriter.getI32IntegerAttr(static_cast(inputZp)), - rewriter.getI32IntegerAttr(static_cast(outputZp)), - rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI8ArrayAttr({static_cast(shift)}), - rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(doubleRound), - rewriter.getBoolAttr(perChannel), + loc, outputType, inputVal, multiplierVal, shiftVal, inputZpVal.value(), + outputZpVal.value(), rewriter.getBoolAttr(scale32), + rewriter.getStringAttr(roundingMode), rewriter.getBoolAttr(perChannel), /*input_unsigned=*/rewriter.getBoolAttr(false), /*output_unsigned=*/rewriter.getBoolAttr(false)); @@ -279,6 +302,13 @@ LogicalResult matchAndRewriteBinaryOp(StablehloOp op, PatternRewriter &rewriter, "types only"); } + if (!lhsQType.isSigned() || !rhsQType.isSigned() || !resultQType.isSigned()) { + return rewriter.notifyMatchFailure( + op, + "The conversion supports operands/results with signed storage types " + "only"); + } + double lhsRescaleScale, rhsRescaleScale, resultRescaleScale; rescaleScalesFn(lhsQType, rhsQType, resultQType, lhsRescaleScale, diff --git a/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp b/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp index 11bc22ba41..87b1baeb2d 100644 --- a/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp +++ b/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp @@ -65,10 +65,10 @@ LogicalResult ConvertTosaRescaleToStablehlo::matchAndRewrite( } bool scale32 = op.getScale32(); - bool doubleRound = op.getDoubleRound(); + auto roundingMode = op.getRoundingMode(); bool perChannel = op.getPerChannel(); - if (perChannel || doubleRound || !scale32) { + if (perChannel || roundingMode != "SINGLE_ROUND" || !scale32) { return rewriter.notifyMatchFailure( op, "per_channel, double_round, or scale32=false are not yet supported"); @@ -108,16 +108,50 @@ LogicalResult ConvertTosaRescaleToStablehlo::matchAndRewrite( // construct multiplier, shift constant values from op attrs // for scale32, multiplier is tensor of i32 + + DenseElementsAttr multiplierElems; + if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) + return rewriter.notifyMatchFailure( + op, "requires constant multiplier input values"); + + llvm::SmallVector multiplierValues = llvm::to_vector( + llvm::map_range(multiplierElems.getValues(), + [](IntegerAttr attr) -> int32_t { + return static_cast(attr.getInt()); + })); + + // The shift and multiplier values. + DenseElementsAttr shiftElems; + if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) + return rewriter.notifyMatchFailure(op, + "requires constant shift input values"); + + llvm::SmallVector shiftValues = + llvm::to_vector(shiftElems.getValues()); + Value multiplier = getStablehloConstantOp( - rewriter, loc, DenseElementsAttr::get(i32Type, op.getMultiplier())); + rewriter, loc, DenseElementsAttr::get(i32Type, multiplierValues.front())); Value shift = getStablehloConstantOp( - rewriter, loc, DenseElementsAttr::get(i8Type, op.getShift())); + rewriter, loc, DenseElementsAttr::get(i8Type, shiftValues.front())); + + FailureOr maybeIZp = op.getInputZeroPoint(); + if (failed(maybeIZp)) { + return rewriter.notifyMatchFailure( + op, "requires constant input zero point value"); + } + FailureOr maybeOZp = op.getOutputZeroPoint(); + if (failed(maybeOZp)) { + return rewriter.notifyMatchFailure( + op, "requires constant output zero point value"); + } // construct inputZp and outputZp from op attrs Value inputZpI32 = getStablehloConstantOp( - rewriter, loc, DenseElementsAttr::get(i32Type, op.getInputZpAttr())); + rewriter, loc, + DenseElementsAttr::get(i32Type, {static_cast(*maybeIZp)})); Value outputZpI32 = getStablehloConstantOp( - rewriter, loc, DenseElementsAttr::get(i32Type, op.getOutputZpAttr())); + rewriter, loc, + DenseElementsAttr::get(i32Type, {static_cast(*maybeOZp)})); // construct constant 1, min and max tensors Value onesI64 = getStablehloConstantOp(rewriter, loc,