@@ -30,6 +30,7 @@ limitations under the License.
3030
3131#include < gmock/gmock.h>
3232#include < gtest/gtest.h>
33+ #include " absl/log/absl_check.h"
3334#include " tflite/core/interpreter.h"
3435#include " tflite/kernels/test_util.h"
3536#include " tflite/schema/schema_generated.h"
@@ -159,22 +160,34 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
159160 std::vector<int64_t > per_channel_quantization_offsets (
160161 per_channel_quantization_scales.size (), 0 );
161162 weights_ = AddInput ({filter_type,
162- {units_, input_size_},
163- 0 ,
164- 0 ,
165- 0 ,
166- 0 ,
167- true ,
163+ /* shape= */ {units_, input_size_},
164+ /* min= */ 0 ,
165+ /* max= */ 0 ,
166+ /* scale= */ 0 ,
167+ /* zero_point= */ 0 ,
168+ /* per_channel_quantization= */ true ,
168169 per_channel_quantization_scales,
169170 per_channel_quantization_offsets,
170- 0 });
171+ /* channel_index= */ 0 });
171172 } else {
172173 // per-tensor
173174 float min = input.min ;
174175 float max = input.max ;
175- if (filter_type == TensorType_INT4 || filter_type == TensorType_INT8) {
176- min = filter_type == TensorType_INT4 ? -7 .f : -63 .5f ;
177- max = filter_type == TensorType_INT4 ? 7 .f : 64 .f ;
176+ switch (filter_type) {
177+ case TensorType_INT4:
178+ min = -7 .f ;
179+ max = 7 .f ;
180+ break ;
181+ case TensorType_INT2:
182+ min = -2 .f ;
183+ max = 2 .f ;
184+ break ;
185+ case TensorType_INT8:
186+ min = -63 .5f ;
187+ max = 64 .f ;
188+ break ;
189+ default :
190+ break ;
178191 }
179192 weights_ = AddInput ({filter_type, {units_, input_size_}, min, max});
180193 }
@@ -292,6 +305,13 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
292305 QuantizeAndPopulate4bit (weights_, data);
293306 }
294307
308+ void SetWeights2bit (const std::vector<float >& data) {
309+ TfLiteTensor* t = interpreter_->tensor (weights_);
310+ std::vector<int8_t > u =
311+ Quantize<int8_t >(data, t->params .scale , t->params .zero_point , t->type );
312+ PopulateTensor2bit (weights_, 0 , u.data (), u.data () + u.size ());
313+ }
314+
295315 template <typename T>
296316 void ShuffleAndSetWeights (const std::vector<float >& data, int input_depth,
297317 int output_depth) {
@@ -372,6 +392,12 @@ class PerChannelQuantizedFullyConnectedOpModel
372392 PerChannelSymmetricQuantizeAndPopulate (weights_, data);
373393 }
374394
// Test helper that loads 2-bit weights for the per-channel quantized model
// from the float values in `data`. It first checks (ABSL_CHECK_EQ) that the
// weights tensor has type kTfLiteInt2, then forwards `data` unchanged to
// PerChannelSymmetricQuantizeAndPopulate.
395+ void SetWeights2bit (const std::vector<float >& data) {
396+ // 2 bit logic handled in PerChannelSymmetricQuantizeAndPopulate.
// NOTE(review): that helper is defined outside this view -- confirm that it
// really handles INT2 tensors as the comment above states.
397+ ABSL_CHECK_EQ (interpreter_->tensor (weights_)->type , kTfLiteInt2 );
398+ PerChannelSymmetricQuantizeAndPopulate (weights_, data);
399+ }
400+
375401 template <typename T>
376402 void SetInput (const std::vector<float >& data) {
377403 QuantizeAndPopulate<T>(input_, data);
@@ -734,6 +760,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt4) {
734760 EXPECT_THAT (m.GetOutput <int8_t >(), ElementsAre (103 , 104 , 105 , 97 , 98 , 99 ));
735761}
736762
763+ TEST_P (QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt2) {
764+ QuantizedFullyConnectedOpModel m (
765+ GetRegistration (), /* units=*/ 3 , /* batches*/ 2 ,
766+ /* input=*/ {TensorType_INT8, {2 , 10 }, -63.5 , 64 },
767+ /* output=*/ {TensorType_INT8, {}, -127 , 128 }, TensorType_INT32, false ,
768+ false , ActivationFunctionType_RELU,
769+ FullyConnectedOptionsWeightsFormat_DEFAULT, -1 , TensorType_INT2);
770+
771+ m.SetWeights2bit ({
772+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 0
773+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 1
774+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 2
775+ });
776+ m.SetBias ({1 ., 2 ., 3 .});
777+
778+ m.SetInput <int8_t >({
779+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , -9 , -10 , // b = 0
780+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , -8 , 9 , -10 , // b = 1
781+ });
782+
783+ // The quantization parameters for the model.
784+ // input s, zp: 0.5, -1
785+ // filter s, zp: 0.5, 0
786+ // output s, zp: 1, -1
787+
788+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
789+ EXPECT_THAT (m.GetDequantizedOutput <int8_t >(),
790+ testing::Pointwise (testing::FloatEq (),
791+ {26.0 , 27.0 , 28.0 , 8.0 , 9.0 , 10.0 }));
792+ EXPECT_THAT (m.GetOutput <int8_t >(), ElementsAre (25 , 26 , 27 , 7 , 8 , 9 ));
793+ }
794+
737795TEST_P (QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
738796 QuantizedFullyConnectedOpModel m (
739797 GetRegistration (), /* units=*/ 3 , /* batches*/ 2 ,
@@ -863,6 +921,34 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt4) {
863921 EXPECT_THAT (m.GetOutput <int8_t >(), ElementsAre (103 , 104 , 105 , 97 , 98 , 99 ));
864922}
865923
924+ TEST_P (QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt2) {
925+ PerChannelQuantizedFullyConnectedOpModel m (
926+ GetRegistration (), /* units=*/ 3 , /* batches*/ 2 ,
927+ /* input=*/ {TensorType_INT8, {2 , 10 }, -63.5 , 64 },
928+ /* per_channel_quantization_scales=*/ {1.0 , 1.0 , 1.0 },
929+ /* output=*/ {TensorType_INT8, {}, -127 , 128 },
930+ /* bias_type=*/ TensorType_INT32, false , false , ActivationFunctionType_RELU,
931+ FullyConnectedOptionsWeightsFormat_DEFAULT, -1 , TensorType_INT2);
932+
933+ m.SetWeights2bit ({
934+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 0
935+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 1
936+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 2
937+ });
938+ m.SetBias ({1 , 2 , 3 });
939+
940+ m.SetInput <int8_t >({
941+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , -9 , -10 , // b = 0
942+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , -8 , 9 , -10 , // b = 1
943+ });
944+
945+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
946+
947+ EXPECT_THAT (m.GetDequantizedOutput <int8_t >(),
948+ ElementsAreArray (ArrayFloatNear ({26 , 27 , 28 , 8 , 9 , 10 })));
949+ EXPECT_THAT (m.GetOutput <int8_t >(), ElementsAre (25 , 26 , 27 , 7 , 8 , 9 ));
950+ }
951+
866952TEST_P (QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16NoBias) {
867953 const float scale = 128.0 / 65536 ;
868954 QuantizedFullyConnectedOpModel m (
@@ -1018,6 +1104,37 @@ TEST_P(QuantizedFullyConnectedOpTest,
10181104 ElementsAre (1536 , 2048 , 2560 , 11776 , 12288 , 12800 ));
10191105}
10201106
1107+ TEST_P (QuantizedFullyConnectedOpTest,
1108+ SimpleTestPerChannelQuantizedInt16Bias32Weight2) {
1109+ const float scale = 128.0 / 65536 ;
1110+ PerChannelQuantizedFullyConnectedOpModel m (
1111+ GetRegistration (), /* units=*/ 3 , /* batches*/ 2 ,
1112+ /* input=*/ {TensorType_INT16, {2 , 10 }, 0 , 0 , scale, 0 },
1113+ /* per_channel_quantization_scales=*/ {1.0 , 1.0 , 1.0 },
1114+ /* output=*/ {TensorType_INT16, {}, 0 , 0 , scale, 0 },
1115+ /* bias_type=*/ TensorType_INT32, false , false , ActivationFunctionType_RELU,
1116+ FullyConnectedOptionsWeightsFormat_DEFAULT, -1 , TensorType_INT2);
1117+
1118+ m.SetWeights2bit ({
1119+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 0
1120+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 1
1121+ 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , -1 , 0 , // u = 2
1122+ });
1123+ m.SetBias ({1 , 2 , 3 });
1124+
1125+ m.SetInput <int16_t >({
1126+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , -9 , -10 , // b = 0
1127+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , -8 , 9 , -10 , // b = 1
1128+ });
1129+
1130+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
1131+
1132+ EXPECT_THAT (m.GetDequantizedOutput <int16_t >(),
1133+ ElementsAreArray (ArrayFloatNear ({26 , 27 , 28 , 8 , 9 , 10 })));
1134+ EXPECT_THAT (m.GetOutput <int16_t >(),
1135+ ElementsAre (13312 , 13824 , 14336 , 4096 , 4608 , 5120 ));
1136+ }
1137+
10211138TEST_P (QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias64) {
10221139 const float scale = 128.0 / 65536 ;
10231140 QuantizedFullyConnectedOpModel m (
0 commit comments