Skip to content

Commit be0f046

Browse files
majiddadashi authored and copybara-github committed
Add support for kTfLiteInt2 to Dequantize kernels.
This change enables the Dequantize and PerChannelDequantize operations to handle 2-bit integer inputs (`kTfLiteInt2`). It includes logic to unpack the packed 2-bit integers into int8_t before performing the dequantization and adds new test cases for both per-tensor and per-channel dequantization with kTfLiteInt2. PiperOrigin-RevId: 822207279
1 parent d13d4bf commit be0f046

File tree

5 files changed

+64
-9
lines changed

5 files changed

+64
-9
lines changed

tflite/core/kernels/register.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
179179
/* max_version = */ 8);
180180
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
181181
/* min_version = */ 1,
182-
/* max_version = */ 6);
182+
/* max_version = */ 7);
183183
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
184184
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
185185
/* min_version = */ 1,

tflite/kernels/dequantize.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
5757

5858
TF_LITE_ENSURE(context, op_context.input != nullptr);
5959

60-
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteInt4 ||
60+
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteInt2 ||
61+
op_context.input->type == kTfLiteInt4 ||
6162
op_context.input->type == kTfLiteUInt8 ||
6263
op_context.input->type == kTfLiteInt8 ||
6364
op_context.input->type == kTfLiteInt16 ||

tflite/kernels/dequantize.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,24 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
7272
per_channel_op_params.zero_point = zero_points.data();
7373
}
7474
const int8_t* input_data;
75-
const size_t bytes_unpacked = input->bytes * 2;
75+
size_t bytes_unpacked;
76+
if (input->type == kTfLiteInt2) {
77+
bytes_unpacked = input->bytes * 4;
78+
} else {
79+
bytes_unpacked = input->bytes * 2;
80+
}
7681
auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
7782

7883
if (input->type == kTfLiteInt4) {
7984
tflite::tensor_utils::UnpackPackedIntToInt8(
8085
GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
8186
/*bit_width=*/4, unpacked_input_data.get());
8287
input_data = unpacked_input_data.get();
88+
} else if (input->type == kTfLiteInt2) {
89+
tflite::tensor_utils::UnpackPackedIntToInt8(
90+
GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
91+
/*bit_width=*/2, unpacked_input_data.get());
92+
input_data = unpacked_input_data.get();
8393
} else {
8494
input_data = GetTensorData<int8_t>(input);
8595
}
@@ -91,6 +101,7 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
91101
GetTensorData<uint8_t>(input), GetTensorShape(output),
92102
GetTensorData<float>(output));
93103
break;
104+
case kTfLiteInt2:
94105
case kTfLiteInt4:
95106
case kTfLiteInt8:
96107
reference_ops::PerChannelDequantize<int8_t>(
@@ -115,7 +126,12 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
115126
op_params.zero_point = input->params.zero_point;
116127
op_params.scale = input->params.scale;
117128
const int8_t* input_data;
118-
const size_t bytes_unpacked = input->bytes * 2;
129+
size_t bytes_unpacked;
130+
if (input->type == kTfLiteInt2) {
131+
bytes_unpacked = input->bytes * 4;
132+
} else {
133+
bytes_unpacked = input->bytes * 2;
134+
}
119135
auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
120136

121137
if (input->type == kTfLiteInt4) {
@@ -124,6 +140,12 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
124140
GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
125141
/*bit_width=*/4, unpacked_input_data.get());
126142
input_data = unpacked_input_data.get();
143+
} else if (input->type == kTfLiteInt2) {
144+
// Use GetTensorShape(input).FlatSize() for num_elements.
145+
tflite::tensor_utils::UnpackPackedIntToInt8(
146+
GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
147+
/*bit_width=*/2, unpacked_input_data.get());
148+
input_data = unpacked_input_data.get();
127149
} else {
128150
input_data = GetTensorData<int8_t>(input);
129151
}
@@ -140,6 +162,7 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
140162
GetTensorShape(output), GetTensorData<float>(output));
141163
}
142164
break;
165+
case kTfLiteInt2:
143166
case kTfLiteInt4:
144167
case kTfLiteInt8:
145168
if (kernel_type == kReference) {

tflite/kernels/dequantize_test.cc

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,8 @@ limitations under the License.
1919

2020
#include <gmock/gmock.h>
2121
#include <gtest/gtest.h>
22-
#include "absl/memory/memory.h"
2322
#include "Eigen/Core" // from @eigen_archive
24-
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
25-
#include "tflite/core/api/op_resolver.h"
2623
#include "tflite/core/interpreter.h"
27-
#include "tflite/kernels/internal/types.h"
2824
#include "tflite/kernels/test_util.h"
2925
#include "tflite/schema/schema_generated.h"
3026

@@ -75,6 +71,15 @@ class DequantizeOpModel : public SingleOpModel {
7571
data_int8.data() + data_int8.size());
7672
}
7773

74+
template <typename T>
75+
void SetInputInt2(int input, const std::vector<T> data) {
76+
auto non_const = *const_cast<std::vector<T>*>(&data);
77+
std::vector<int8_t> data_int8(non_const.size());
78+
std::copy(non_const.begin(), non_const.end(), data_int8.begin());
79+
PopulateTensor2bit(input, 0, data_int8.data(),
80+
data_int8.data() + data_int8.size());
81+
}
82+
7883
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
7984

8085
protected:
@@ -92,6 +97,15 @@ TEST(DequantizeOpTest, Int4) {
9297
ElementsAreArray(ArrayFloatNear({4, 3.5, -3, -3.5})));
9398
}
9499

100+
TEST(DequantizeOpTest, Int2) {
101+
DequantizeOpModel m(TensorType_INT2, {1, 4}, 0.5, -1, 6);
102+
103+
m.SetInputInt2<int8_t>(0, {1, 0, -1, -2});
104+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
105+
EXPECT_THAT(m.GetOutput(),
106+
ElementsAreArray(ArrayFloatNear({1.0, 0.5, 0.0, -0.5})));
107+
}
108+
95109
TEST(DequantizeOpTest, Uint8) {
96110
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
97111
DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
@@ -185,5 +199,22 @@ TEST(DequantizePerChannelOpTest, Int8) {
185199
{-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64})));
186200
}
187201

202+
TEST(DequantizePerChannelOpTest, Int2) {
203+
// scales={0.5, 1.0}, zero_points={-1, 0}, channel_dim=0
204+
DequantizePerChannelOpModel m(TensorType_INT2, {2, 2}, {0.5, 1.0}, {-1, 0}, 0,
205+
6);
206+
m.SetInputInt2<int8_t>(0, {1, 0, -1, -2});
207+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
208+
// Dequantization formula: (val - zp) * scale
209+
// Channel 0: scale=0.5, zp=-1.
210+
// val=1: (1 - (-1)) * 0.5 = 1.0
211+
// val=0: (0 - (-1)) * 0.5 = 0.5
212+
// Channel 1: scale=1.0, zp=0
213+
// val=-1: (-1 - 0) * 1.0 = -1.0
214+
// val=-2: (-2 - 0) * 1.0 = -2.0
215+
EXPECT_THAT(m.GetOutput(),
216+
ElementsAreArray(ArrayFloatNear({1.0, 0.5, -1.0, -2.0})));
217+
}
218+
188219
} // namespace
189220
} // namespace tflite

tflite/kernels/register_ref.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
380380
/* max_version = */ 8);
381381
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
382382
/* min_version = */ 1,
383-
/* max_version = */ 6);
383+
/* max_version = */ 7);
384384
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU_REF());
385385
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM_REF(),
386386
/* min_version = */ 1,

0 commit comments

Comments (0)