Skip to content

Commit 2d5f2c0

Browse files
majiddadashi authored and copybara-github committed
Enable lowering from FQ Composite for 2-bit
This also adds an additional test for this lowering.

LiteRT-Converter-PiperOrigin-RevId: 820534395
1 parent e08e690 commit 2d5f2c0

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

tflite/converter/quantization/common/quantization_lib/quantization.td

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ class Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
5656

5757
// General uniform quantized types. The definitions can be used to specify
5858
// operand's tensor types.
59+
def QI2 : QuantizedType<"Uniform", [2], 1>;
5960
def QI4 : QuantizedType<"Uniform", [4], 1>;
6061
def QUI8 : QuantizedType<"Uniform", [8], 0>;
6162
def QI8 : QuantizedType<"Uniform", [8], 1>;

tflite/converter/tests/lower_quant_annotations.mlir

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
func.func private @XlaCallModule_quant.fake_quant.impl_0(tensor<1x28x28x3xf32>) -> tensor<1x28x28x3xf32>
44
func.func private @XlaCallModule_quant.fake_quant.impl_5_0(tensor<2x1x1x1xf32>) -> tensor<2x1x1x1xf32>
55
func.func private @XlaCallModule_quant.fake_quant.impl_17_0(tensor<1x30x30x2xf32>) -> tensor<1x30x30x2xf32>
6+
func.func private @XlaCallModule_quant.fake_quant.impl_i2_0(tensor<1x4xf32>) -> tensor<1x4xf32>
7+
func.func private @XlaCallModule_quant.fake_quant.impl_i2_1(tensor<1x4xf32>) -> tensor<1x4xf32>
68
// CHECK-LABEL: func.func @serving_default
79
func.func @serving_default(%arg0: tensor<1x28x28x3xf32>) -> (tensor<1x30x30x2xf32>) {
810
%cst = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
@@ -22,4 +24,15 @@ func.func @serving_default(%arg0: tensor<1x28x28x3xf32>) -> (tensor<1x30x30x2xf3
2224
// CHECK-OFF: %[[DEQUANT2:.+]] = "tfl.dequantize"(%[[QUANT2]]) : (tensor<1x30x30x2x!quant.uniform<i8:f32, 0.018049469217658043:8>>) -> tensor<1x30x30x2xf32>
2325
%5 = stablehlo.composite "quant.fake_quant" %4 {composite_attributes = {dtype = "i8", narrow_range = false, scale = dense<0.0180494692> : tensor<1xf32>, zero_point = dense<8> : tensor<1xi32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_17_0} : (tensor<1x30x30x2xf32>) -> tensor<1x30x30x2xf32>
2426
return %5 : tensor<1x30x30x2xf32>
27+
}
28+
29+
// CHECK-LABEL: func.func @i2_test
30+
func.func @i2_test(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>) {
31+
// CHECK: %[[QUANT0:.+]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>
32+
// CHECK: %[[DEQUANT0:.+]] = "tfl.dequantize"(%[[QUANT0]]) : (tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>) -> tensor<1x4xf32>
33+
%0 = stablehlo.composite "quant.fake_quant" %arg0 {composite_attributes = {dtype = "i2", narrow_range = false, scale = dense<1.0> : tensor<1xf32>, zero_point = dense<0> : tensor<1xi32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_0} : (tensor<1x4xf32>) -> tensor<1x4xf32>
34+
// CHECK: %[[QUANT1:.+]] = "tfl.quantize"(%[[DEQUANT0]]) <{qtype = tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>
35+
// CHECK: %[[DEQUANT1:.+]] = "tfl.dequantize"(%[[QUANT1]]) : (tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>) -> tensor<1x4xf32>
36+
%1 = stablehlo.composite "quant.fake_quant" %0 {composite_attributes = {dtype = "i2", narrow_range = true, quantization_dimension = 1 : i32, scale = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_1} : (tensor<1x4xf32>) -> tensor<1x4xf32>
37+
return %1 : tensor<1x4xf32>
2538
}

tflite/converter/transforms/lower_quant_annotations_helper.cc

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,12 +71,15 @@ LogicalResult FillCompositeParams(stablehlo::CompositeOp op,
7171
return failure();
7272
}
7373
std::string dtype = dtype_attr.getValue().str();
74-
if (dtype == "i8") {
75-
num_bits = 8;
74+
if (dtype == "i2") {
75+
num_bits = 2;
7676
is_signed = true;
7777
} else if (dtype == "i4") {
7878
num_bits = 4;
7979
is_signed = true;
80+
} else if (dtype == "i8") {
81+
num_bits = 8;
82+
is_signed = true;
8083
} else {
8184
return failure();
8285
}
@@ -110,7 +113,16 @@ LogicalResult GetStorageParams(unsigned num_bits, bool narrow_range,
110113
bool is_signed, MLIRContext* ctx,
111114
Type& storage_type, int64_t& qmin,
112115
int64_t& qmax) {
113-
if (num_bits <= 4) {
116+
if (num_bits == 2) {
117+
storage_type = IntegerType::get(ctx, 2);
118+
if (is_signed) {
119+
qmin = -2;
120+
qmax = 1;
121+
} else {
122+
qmin = 0;
123+
qmax = 3;
124+
}
125+
} else if (num_bits <= 4) {
114126
storage_type = IntegerType::get(ctx, 4);
115127
if (is_signed) {
116128
qmin = -8;

0 commit comments

Comments
 (0)