Commit 7a5bccc

majiddadashi authored and copybara-github committed
Add support for kTfLiteInt2 (srq) in tfl.fully_connected.
PiperOrigin-RevId: 822405584
1 parent 8afe70b commit 7a5bccc

File tree

6 files changed: +181, -24 lines changed

tflite/core/kernels/register.cc

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
                      Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
              /* min_version = */ 1,
-             /* max_version = */ 13);
+             /* max_version = */ 14);
   AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
   AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
   AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),

tflite/kernels/BUILD

Lines changed: 1 addition & 0 deletions
@@ -2225,6 +2225,7 @@ cc_test(
         "//tflite/core/api",
         "//tflite/kernels/internal:tensor_utils",
         "//tflite/schema:schema_fbs",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
         "@flatbuffers",

tflite/kernels/fully_connected.cc

Lines changed: 26 additions & 12 deletions
@@ -186,7 +186,7 @@ inline TfLiteStatus CheckTypes(TfLiteContext* context,
                                TfLiteFullyConnectedParams* params) {
   const bool is_quantized =
       ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8) ||
-       (filter->type == kTfLiteInt4));
+       (filter->type == kTfLiteInt4) || (filter->type == kTfLiteInt2));
   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
   const bool is_shuffled =
       is_quantized && (params->weights_format ==
@@ -448,7 +448,8 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node,
     TF_LITE_ENSURE(context,
                    input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
     TF_LITE_ENSURE(context, (filter->type == kTfLiteInt8 ||
-                             filter->type == kTfLiteInt4));
+                             filter->type == kTfLiteInt4 ||
+                             filter->type == kTfLiteInt2));
     TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                       per_channel_quantization_size);
     TF_LITE_ENSURE_EQ(
@@ -654,7 +655,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const bool is_quantized =
       ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8) ||
-       (filter->type == kTfLiteInt4));
+       (filter->type == kTfLiteInt4) || (filter->type == kTfLiteInt2));
   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
   const bool is_pie = kernel_type == kLegacyPie;

@@ -666,7 +667,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                        params->activation == kTfLiteActReluN1To1 ||
                        params->activation == kTfLiteActRelu6);
   }
-  if (filter->type == kTfLiteInt4) {
+  if (filter->type == kTfLiteInt4 || filter->type == kTfLiteInt2) {
     TF_LITE_ENSURE_MSG(
         context,
         kTfLiteOk == VerifyQuantizationZeroPoint(filter, /*expected_value=*/0),
@@ -1420,6 +1421,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
     case kTfLiteUInt8:
       if (kernel_type == kReference) {
         TF_LITE_ENSURE(context, filter->type != kTfLiteInt4);
+        TF_LITE_ENSURE(context, filter->type != kTfLiteInt2);
         reference_ops::FullyConnected(
             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
             GetTensorShape(filter), GetTensorData<uint8_t>(filter),
@@ -1456,8 +1458,10 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
             "Invalid quantized and sparse fully-connected format.");
         return kTfLiteError;
       }
-      // Int4 support for sparse filter tensor is currently not supported
+      // Int4/Int2 support for sparse filter tensor is currently not
+      // supported
       TF_LITE_ENSURE(context, filter->type != kTfLiteInt4);
+      TF_LITE_ENSURE(context, filter->type != kTfLiteInt2);
       if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
           sparsity.dim_metadata[2].dense_size == 16) {
         // Block sparse with block size of 1x16.
@@ -1485,6 +1489,14 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
               GetTensorShape(filter).FlatSize(), /*bit_width=*/4,
               unpacked_filter_data.get());
           filter_data = unpacked_filter_data.get();
+        } else if (filter->type == kTfLiteInt2) {
+          const size_t bytes_unpacked = filter->bytes * 4;
+          unpacked_filter_data = std::make_unique<int8_t[]>(bytes_unpacked);
+          tflite::tensor_utils::UnpackPackedIntToInt8(
+              GetTensorData<int8_t>(filter),
+              GetTensorShape(filter).FlatSize(), /*bit_width=*/2,
+              unpacked_filter_data.get());
+          filter_data = unpacked_filter_data.get();
         } else {
           filter_data = GetTensorData<int8_t>(filter);
         }
@@ -1514,6 +1526,14 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
               GetTensorShape(filter).FlatSize(), /*bit_width=*/4,
               unpacked_filter_data.get());
           filter_data = unpacked_filter_data.get();
+        } else if (filter->type == kTfLiteInt2) {
+          const size_t bytes_unpacked = filter->bytes * 4;
+          unpacked_filter_data = std::make_unique<int8_t[]>(bytes_unpacked);
+          tflite::tensor_utils::UnpackPackedIntToInt8(
+              GetTensorData<int8_t>(filter),
+              GetTensorShape(filter).FlatSize(), /*bit_width=*/2,
+              unpacked_filter_data.get());
+          filter_data = unpacked_filter_data.get();
         } else {
           filter_data = GetTensorData<int8_t>(filter);
         }
@@ -1762,14 +1782,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         return kTfLiteError;
       }
     case kTfLiteInt8:
-      if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
-        return EvalQuantized<kernel_type>(context, node, params, data, input,
-                                          filter, bias, output);
-      } else {
-        TF_LITE_KERNEL_LOG(context, "Unhandled fully-connected weights format");
-        return kTfLiteError;
-      }
     case kTfLiteInt4:
+    case kTfLiteInt2:
       if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
         return EvalQuantized<kernel_type>(context, node, params, data, input,
                                           filter, bias, output);
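
The new kTfLiteInt2 branches mirror the existing Int4 path: a packed filter of filter->bytes bytes expands into filter->bytes * 4 int8 values, since each byte carries four 2-bit weights. A minimal sketch of what a bit_width=2 unpack amounts to, assuming four values per byte with element 0 in the least significant bits; the byte ordering is an assumption here, and the authoritative implementation is tflite::tensor_utils::UnpackPackedIntToInt8:

#include <cstdint>

// Sketch only: expand packed 2-bit weights into sign-extended int8 values.
// Assumes four 2-bit fields per byte, element 0 in the lowest bits.
void UnpackInt2ToInt8Sketch(const int8_t* packed, int num_elements,
                            int8_t* unpacked) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 4]);
    const uint8_t two_bits = (byte >> (2 * (i % 4))) & 0x3;
    // Sign-extend into the 2-bit two's-complement range [-2, 1].
    unpacked[i] = static_cast<int8_t>((two_bits ^ 0x2) - 0x2);
  }
}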

tflite/kernels/fully_connected_test.cc

Lines changed: 127 additions & 10 deletions
@@ -30,6 +30,7 @@ limitations under the License.

 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/log/absl_check.h"
 #include "tflite/core/interpreter.h"
 #include "tflite/kernels/test_util.h"
 #include "tflite/schema/schema_generated.h"
@@ -159,22 +160,34 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
       std::vector<int64_t> per_channel_quantization_offsets(
           per_channel_quantization_scales.size(), 0);
       weights_ = AddInput({filter_type,
-                           {units_, input_size_},
-                           0,
-                           0,
-                           0,
-                           0,
-                           true,
+                           /*shape=*/{units_, input_size_},
+                           /*min=*/0,
+                           /*max=*/0,
+                           /*scale=*/0,
+                           /*zero_point=*/0,
+                           /*per_channel_quantization=*/true,
                            per_channel_quantization_scales,
                            per_channel_quantization_offsets,
-                           0});
+                           /*channel_index=*/0});
     } else {
       // per-tensor
       float min = input.min;
       float max = input.max;
-      if (filter_type == TensorType_INT4 || filter_type == TensorType_INT8) {
-        min = filter_type == TensorType_INT4 ? -7.f : -63.5f;
-        max = filter_type == TensorType_INT4 ? 7.f : 64.f;
+      switch (filter_type) {
+        case TensorType_INT4:
+          min = -7.f;
+          max = 7.f;
+          break;
+        case TensorType_INT2:
+          min = -2.f;
+          max = 2.f;
+          break;
+        case TensorType_INT8:
+          min = -63.5f;
+          max = 64.f;
+          break;
+        default:
+          break;
       }
       weights_ = AddInput({filter_type, {units_, input_size_}, min, max});
     }
@@ -292,6 +305,13 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
     QuantizeAndPopulate4bit(weights_, data);
   }

+  void SetWeights2bit(const std::vector<float>& data) {
+    TfLiteTensor* t = interpreter_->tensor(weights_);
+    std::vector<int8_t> u =
+        Quantize<int8_t>(data, t->params.scale, t->params.zero_point, t->type);
+    PopulateTensor2bit(weights_, 0, u.data(), u.data() + u.size());
+  }
+
   template <typename T>
   void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
                             int output_depth) {
@@ -372,6 +392,12 @@ class PerChannelQuantizedFullyConnectedOpModel
     PerChannelSymmetricQuantizeAndPopulate(weights_, data);
   }

+  void SetWeights2bit(const std::vector<float>& data) {
+    // 2 bit logic handled in PerChannelSymmetricQuantizeAndPopulate.
+    ABSL_CHECK_EQ(interpreter_->tensor(weights_)->type, kTfLiteInt2);
+    PerChannelSymmetricQuantizeAndPopulate(weights_, data);
+  }
+
   template <typename T>
   void SetInput(const std::vector<float>& data) {
     QuantizeAndPopulate<T>(input_, data);
@@ -734,6 +760,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt4) {
   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(103, 104, 105, 97, 98, 99));
 }

+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt2) {
+  QuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
+      /*output=*/{TensorType_INT8, {}, -127, 128}, TensorType_INT32, false,
+      false, ActivationFunctionType_RELU,
+      FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT2);
+
+  m.SetWeights2bit({
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 0
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 1
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 2
+  });
+  m.SetBias({1., 2., 3.});
+
+  m.SetInput<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  // The quantization parameters for the model.
+  // input s, zp: 0.5, -1
+  // filter s, zp: 0.5, 0
+  // output s, zp: 1, -1
+
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              testing::Pointwise(testing::FloatEq(),
+                                 {26.0, 27.0, 28.0, 8.0, 9.0, 10.0}));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(25, 26, 27, 7, 8, 9));
+}
+
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
   QuantizedFullyConnectedOpModel m(
       GetRegistration(), /*units=*/3, /*batches*/ 2,
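
Sanity-checking the Int2 test expectations above against the float reference: for batch 0, unit 0, the dot product is 1*1 + 1*3 + 1*5 + 1*7 + (-1)*(-9) = 25, plus bias 1, giving 26; for batch 1 the -1 weight lands on the +9 input, so the sum is 1 + 3 + 5 + 7 - 9 = 7, plus bias 1, giving 8. With output scale 1 and zero point -1, each stored int8 value is one less than its dequantized value, hence ElementsAre(25, 26, 27, 7, 8, 9).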
@@ -863,6 +921,34 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt4) {
   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(103, 104, 105, 97, 98, 99));
 }

+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt2) {
+  PerChannelQuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
+      /*per_channel_quantization_scales=*/{1.0, 1.0, 1.0},
+      /*output=*/{TensorType_INT8, {}, -127, 128},
+      /*bias_type=*/TensorType_INT32, false, false, ActivationFunctionType_RELU,
+      FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT2);
+
+  m.SetWeights2bit({
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 0
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 1
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 2
+  });
+  m.SetBias({1, 2, 3});
+
+  m.SetInput<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear({26, 27, 28, 8, 9, 10})));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(25, 26, 27, 7, 8, 9));
+}
+
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16NoBias) {
   const float scale = 128.0 / 65536;
   QuantizedFullyConnectedOpModel m(
@@ -1018,6 +1104,37 @@ TEST_P(QuantizedFullyConnectedOpTest,
               ElementsAre(1536, 2048, 2560, 11776, 12288, 12800));
 }

+TEST_P(QuantizedFullyConnectedOpTest,
+       SimpleTestPerChannelQuantizedInt16Bias32Weight2) {
+  const float scale = 128.0 / 65536;
+  PerChannelQuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
+      /*per_channel_quantization_scales=*/{1.0, 1.0, 1.0},
+      /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
+      /*bias_type=*/TensorType_INT32, false, false, ActivationFunctionType_RELU,
+      FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT2);
+
+  m.SetWeights2bit({
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 0
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 1
+      1, 0, 1, 0, 1, 0, 1, 0, -1, 0,  // u = 2
+  });
+  m.SetBias({1, 2, 3});
+
+  m.SetInput<int16_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear({26, 27, 28, 8, 9, 10})));
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAre(13312, 13824, 14336, 4096, 4608, 5120));
+}
+
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias64) {
   const float scale = 128.0 / 65536;
   QuantizedFullyConnectedOpModel m(
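
In the int16 variant above, the output scale is 128.0 / 65536 = 1/512 with zero point 0, so each dequantized value v is stored as v * 512 in the raw tensor: 26 * 512 = 13312, 27 * 512 = 13824, 28 * 512 = 14336, and 8, 9, 10 map to 4096, 4608, 5120.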

tflite/kernels/register_ref.cc

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
                      Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED_REF(),
              /* min_version */ 1,
-             /* max_version */ 11);
+             /* max_version */ 14);
   AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
   AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
   AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX_REF(),

tflite/kernels/test_util.h

Lines changed: 25 additions & 0 deletions
@@ -109,6 +109,9 @@ inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
   if (type == kTfLiteInt4) {
     min = -7;
     max = 7;
+  } else if (type == kTfLiteInt2) {
+    min = -2;
+    max = 1;
   }

   q.reserve(data.size());
@@ -570,6 +573,15 @@ class SingleOpModel {
                        quantized_output.data() + quantized_output.size());
   }

+  void QuantizeAndPopulate2bit(int index, const std::vector<float>& data) {
+    TfLiteTensor* t = interpreter_->tensor(index);
+    t->type = kTfLiteInt2;
+    std::vector<int8_t> quantized_output =
+        Quantize<int8_t>(data, t->params.scale, t->params.zero_point, t->type);
+    PopulateTensor2bit(index, /*offset=*/0, quantized_output.data(),
+                       quantized_output.data() + quantized_output.size());
+  }
+
   void SymmetricQuantizeAndPopulate(int index, const std::vector<float>& data) {
     std::vector<int8_t> q = QuantizeTensor(index, data);
     PopulateTensor(index, /*offset=*/0, reinterpret_cast<uint8_t*>(q.data()),
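
QuantizeAndPopulate2bit relies on PopulateTensor2bit, whose implementation is not part of this diff; presumably it packs the quantized int8 values into the same four-per-byte layout the kernel unpacks. A hypothetical sketch of such a packer, where the function name, signature, and bit ordering are assumptions rather than the library's API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical inverse of the 2-bit unpack: pack int8 values that are
// already in the representable range [-2, 1] into 2-bit fields, four per
// byte, element 0 in the lowest bits.
std::vector<int8_t> PackInt8ToInt2Sketch(const std::vector<int8_t>& values) {
  std::vector<int8_t> packed((values.size() + 3) / 4, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t two_bits = static_cast<uint8_t>(values[i]) & 0x3;
    const uint8_t byte = static_cast<uint8_t>(packed[i / 4]);
    packed[i / 4] = static_cast<int8_t>(byte | (two_bits << (2 * (i % 4))));
  }
  return packed;
}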
@@ -583,6 +595,10 @@ class SingleOpModel {
       std::vector<int8_t> q = Quantize<int8_t>(data, t->params.scale,
                                                t->params.zero_point, t->type);
       PopulateTensor4bit(index, /*offset=*/0, q.data(), q.data() + q.size());
+    } else if (t->type == kTfLiteInt2) {
+      std::vector<int8_t> q = Quantize<int8_t>(data, t->params.scale,
+                                               t->params.zero_point, t->type);
+      PopulateTensor2bit(index, /*offset=*/0, q.data(), q.data() + q.size());
     } else {
       std::vector<int8_t> q = QuantizeTensor(index, data);
       PopulateTensor(index, /*offset=*/0, q.data(), q.data() + q.size());
@@ -663,6 +679,9 @@ class SingleOpModel {
       PopulateTensor4bit(index, /*offset=*/0, quantized_output.data(),
                          quantized_output.data() + quantized_output.size());

+    } else if (t->type == kTfLiteInt2) {
+      PopulateTensor2bit(index, /*offset=*/0, quantized_output.data(),
+                         quantized_output.data() + quantized_output.size());
     } else {
       PopulateTensor(index, /*offset=*/0, quantized_output.data(),
                      quantized_output.data() + quantized_output.size());
@@ -888,6 +907,9 @@ class SingleOpModel {
     } else if (t.type == TensorType_INT4) {
       std::tie(t.scale, t.zero_point) =
           QuantizationParams<int8_t>(t.min, t.max, kTfLiteInt4);
+    } else if (t.type == TensorType_INT2) {
+      std::tie(t.scale, t.zero_point) =
+          QuantizationParams<int8_t>(t.min, t.max, kTfLiteInt2);
     } else {
       ABSL_LOG(FATAL) << "No support for the requested quantized type";
     }
@@ -940,6 +962,9 @@ class SingleOpModel {
     if (type == kTfLiteInt4) {
       qmin = -7;
       qmax = 7;
+    } else if (type == kTfLiteInt2) {
+      qmin = -2;
+      qmax = 2;
     } else {
       qmin = std::numeric_limits<T>::min();
       qmax = std::numeric_limits<T>::max();
