Commit cfbf64f

majiddadashi authored and copybara-github committed
Add support for int2/int4 in tfl.cast
PiperOrigin-RevId: 820509011
1 parent 8697dbf commit cfbf64f

File tree

7 files changed: +211 −31 lines changed


tflite/core/kernels/register.cc

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
              /* min_version = */ 1,
-             /* max_version = */ 7);
+             /* max_version = */ 8);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
              /* max_version = */ 6);

tflite/kernels/BUILD

Lines changed: 5 additions & 0 deletions
@@ -172,6 +172,8 @@ cc_library(
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/base:no_destructor",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/log:absl_log",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
@@ -1490,11 +1492,14 @@ cc_test(
     tags = ["tflite_nnapi"],
     deps = [
         ":cast_test_common",
+        ":kernel_util",
         ":test_main",
         ":test_util",
         "//tflite/c:common",
         "//tflite/core/c:c_api_types",
+        "//tflite/kernels/internal:tensor_utils_no_eigen",
         "//tflite/schema:schema_fbs",
+        "@com_google_absl//absl/random",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest",
         "@eigen_archive//:eigen3",

tflite/kernels/cast.cc

Lines changed: 63 additions & 0 deletions
@@ -18,11 +18,14 @@ limitations under the License.
 #include <cstddef>
 #include <cstdint>
 #include <limits>
+#include <type_traits>
+#include <vector>
 
 #include "Eigen/Core"  // from @eigen_archive
 #include "tflite/core/c/common.h"
 #include "tflite/core/subgraph.h"
 #include "tflite/interpreter_options.h"
+#include "tflite/kernels/internal/portable_tensor_utils.h"
 #include "tflite/kernels/internal/tensor_ctypes.h"
 #include "tflite/kernels/kernel_util.h"
 #include "tflite/kernels/op_macros.h"
@@ -183,6 +186,19 @@ void copyCastToBFloat16(const Eigen::half* in, Eigen::bfloat16* out,
   });
 }
 
+TfLiteStatus castInt2ToFloat(TfLiteContext* context, const TfLiteTensor* in,
+                             TfLiteTensor* out, int num_elements) {
+  const int8_t* in_data = (const int8_t*)in->data.data;
+  float* out_data = (float*)out->data.data;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(in_data, num_elements, /*bit_width=*/2,
+                                      unpacked_temp.data());
+  for (int i = 0; i < num_elements; ++i) {
+    out_data[i] = static_cast<float>(unpacked_temp[i]);
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
                              TfLiteTensor* out, int num_elements) {
   const int8_t* in_data = (const int8_t*)in->data.data;
@@ -240,6 +256,34 @@ TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
   return kTfLiteOk;
 }
 
+TfLiteStatus castFloatToInt4(const float* in, TfLiteTensor* out,
+                             int num_elements) {
+  const float min_val = -8.0f;
+  const float max_val = 7.0f;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  for (int i = 0; i < num_elements; ++i) {
+    unpacked_temp[i] =
+        static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
+  }
+  tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
+                                     /*bit_width=*/4, (int8_t*)out->data.data);
+  return kTfLiteOk;
+}
+
+TfLiteStatus castFloatToInt2(const float* in, TfLiteTensor* out,
+                             int num_elements) {
+  const float min_val = -2.0f;
+  const float max_val = 1.0f;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  for (int i = 0; i < num_elements; ++i) {
+    unpacked_temp[i] =
+        static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
+  }
+  tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
+                                     /*bit_width=*/2, (int8_t*)out->data.data);
+  return kTfLiteOk;
+}
+
 template <typename FromT>
 TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
                           TfLiteTensor* out, int num_elements) {
@@ -286,6 +330,20 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
       copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
                num_elements);
       break;
+    case kTfLiteInt4:
+      if (std::is_same<FromT, float>::value) {
+        return castFloatToInt4(reinterpret_cast<const float*>(in), out,
+                               num_elements);
+      } else {
+        TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
+      }
+    case kTfLiteInt2:
+      if (std::is_same<FromT, float>::value) {
+        return castFloatToInt2(reinterpret_cast<const float*>(in), out,
+                               num_elements);
+      } else {
+        TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
+      }
     default:
       // Unsupported type.
       TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
@@ -334,6 +392,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
         TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
       }
       return castInt4ToFloat(context, input, output, num_elements);
+    case kTfLiteInt2:
+      if (output->type != kTfLiteFloat32) {
+        TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
+      }
+      return castInt2ToFloat(context, input, output, num_elements);
     default:
       // Unsupported type.
       TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");

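For readers following the new cast.cc helpers above: int2 and int4 tensors are stored densely (four 2-bit or two 4-bit values per byte), so the cast kernel round-trips through a temporary int8 buffer via tensor_utils::UnpackPackedIntToInt8 and PackInt8IntoDenseInt. The standalone sketch below illustrates such a packing scheme for the 2-bit case; the PackInt2/UnpackInt2 names are illustrative only, and the low-bits-first ordering is an assumption, since this diff does not show the exact layout the TFLite helpers use.

#include <cstdint>
#include <vector>

// Illustrative sketch, not the TFLite helpers: pack signed 2-bit values in
// [-2, 1] densely, four per byte, low bits first (assumed ordering).
std::vector<uint8_t> PackInt2(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 3) / 4, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t bits = static_cast<uint8_t>(values[i]) & 0x3;  // keep low 2 bits
    packed[i / 4] |= bits << (2 * (i % 4));
  }
  return packed;
}

// Unpack back to int8, sign-extending each 2-bit field (0b10, 0b11 -> -2, -1).
std::vector<int8_t> UnpackInt2(const std::vector<uint8_t>& packed,
                               int num_elements) {
  std::vector<int8_t> out(num_elements);
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t bits = (packed[i / 4] >> (2 * (i % 4))) & 0x3;
    out[i] = (bits & 0x2) ? static_cast<int8_t>(bits) - 4
                          : static_cast<int8_t>(bits);
  }
  return out;
}
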
tflite/kernels/cast_test.cc

Lines changed: 86 additions & 5 deletions
@@ -17,16 +17,18 @@ limitations under the License.
 #include <algorithm>
 #include <complex>
 #include <limits>
-#include <random>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/random/random.h"
 #include "absl/types/span.h"
 #include "Eigen/Core"  // from @eigen_archive
 #include "tflite/c/common.h"
 #include "tflite/core/c/c_api_types.h"
 #include "tflite/kernels/cast_test_common.h"
+#include "tflite/kernels/internal/portable_tensor_utils.h"
+#include "tflite/kernels/kernel_util.h"
 #include "tflite/kernels/test_util.h"
 #include "tflite/schema/schema_generated.h"
 
@@ -45,10 +47,10 @@ TEST(CastOpModel, CastInt4ToFloat) {
 
 TEST(CastOpModel, CastInt4ToFloatLarge) {
   int num_elements = 40;
-  std::random_device random_device;
-  auto rng = std::mt19937(random_device());
-  std::uniform_int_distribution<int8_t> i8dist(-8, 7);
-  auto i8rng = [&] { return i8dist(rng); };
+  absl::BitGen bitgen;
+  auto i8rng = [&] {
+    return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -8, 7);
+  };
   std::vector<int8_t> input(num_elements);
   std::generate(input.begin(), input.end(), i8rng);
   CastOpModel m({TensorType_INT4, {num_elements}},
@@ -60,6 +62,85 @@ TEST(CastOpModel, CastInt4ToFloatLarge) {
   }
 }
 
+TEST(CastOpModel, CastInt2ToFloat) {
+  CastOpModel m({TensorType_INT2, {2, 4}}, {TensorType_FLOAT32, {2, 4}});
+  m.Set2BitInput({1, 0, -1, -2, 1, 0, -1, -2});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              Pointwise(FloatingPointEq(),
+                        {1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f}));
+}
+
+TEST(CastOpModel, CastInt2ToFloatLarge) {
+  int num_elements = 40;
+  absl::BitGen bitgen;
+  auto i2rng = [&] {
+    return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -2, 1);
+  };
+  std::vector<int8_t> input(num_elements);
+  std::generate(input.begin(), input.end(), i2rng);
+  CastOpModel m({TensorType_INT2, {num_elements}},
+                {TensorType_FLOAT32, {num_elements}});
+  m.Set2BitInput(input);
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  for (int i = 0; i < input.size(); ++i) {
+    EXPECT_EQ(m.ExtractVector<float>(m.output())[i], input[i]);
+  }
+}
+
+TEST(CastOpModel, CastFloatToInt4) {
+  CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT4, {2, 4}});
+  m.PopulateTensor<float>(m.input(), {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, -8.f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/4, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, -8}));
+}
+
+TEST(CastOpModel, CastFloatToInt4Clamp) {
+  CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT4, {1, 4}});
+  m.PopulateTensor<float>(m.input(), {100.f, -100.f, 7.9f, -8.9f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/4, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({7, -8, 7, -8}));
+}
+
+TEST(CastOpModel, CastFloatToInt2) {
+  CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT2, {2, 4}});
+  m.PopulateTensor<float>(m.input(),
+                          {1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/2, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, 0, -1, -2, 1, 0, -1, -2}));
+}
+
+TEST(CastOpModel, CastFloatToInt2Clamp) {
+  CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT2, {1, 4}});
+  m.PopulateTensor<float>(m.input(), {100.f, -100.f, 1.9f, -2.9f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/2, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, -2, 1, -2}));
+}
+
 TEST(CastOpModel, CastFloatToUint8Infinity) {
   CastOpModel m({TensorType_FLOAT32, {2}}, {TensorType_UINT8, {2}});
   m.PopulateTensor<float>(m.input(), {std::numeric_limits<float>::infinity(),

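One note on the CastFloatToInt*Clamp tests above: the float-to-int4/int2 path clamps to the narrow type's range first and then truncates toward zero via static_cast, so out-of-range inputs saturate and in-range fractional inputs lose their fraction rather than rounding to nearest. A minimal standalone sketch of that arithmetic, assuming the same [-8, 7] bounds castFloatToInt4 uses:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Clamp to the int4 range, then truncate toward zero; these are the inputs
  // and expected outputs from CastFloatToInt4Clamp.
  const float inputs[] = {100.f, -100.f, 7.9f, -8.9f};
  for (float x : inputs) {
    const float clamped = std::max(-8.0f, std::min(7.0f, x));
    const int8_t quantized = static_cast<int8_t>(clamped);  // 7, -8, 7, -8
    std::printf("%g -> %d\n", x, quantized);
  }
  return 0;
}
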
tflite/kernels/cast_test_common.h

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,10 @@ class CastOpModel : public SingleOpModel {
     PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size());
   }
 
+  void Set2BitInput(absl::Span<const int8_t> data) {
+    PopulateTensor2bit(input_, 0, data.data(), data.data() + data.size());
+  }
+
   int input() const { return input_; }
   int output() const { return output_; }

tflite/kernels/register_ref.cc

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
              /* min_version = */ 1,
-             /* max_version = */ 7);
+             /* max_version = */ 8);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
              /* min_version = */ 1,
              /* max_version = */ 6);
