@@ -17,16 +17,18 @@ limitations under the License.
1717#include < algorithm>
1818#include < complex>
1919#include < limits>
20- #include < random>
2120#include < vector>
2221
2322#include < gmock/gmock.h>
2423#include < gtest/gtest.h>
24+ #include " absl/random/random.h"
2525#include " absl/types/span.h"
2626#include " Eigen/Core" // from @eigen_archive
2727#include " tflite/c/common.h"
2828#include " tflite/core/c/c_api_types.h"
2929#include " tflite/kernels/cast_test_common.h"
30+ #include " tflite/kernels/internal/portable_tensor_utils.h"
31+ #include " tflite/kernels/kernel_util.h"
3032#include " tflite/kernels/test_util.h"
3133#include " tflite/schema/schema_generated.h"
3234
@@ -45,10 +47,10 @@ TEST(CastOpModel, CastInt4ToFloat) {
4547
4648TEST (CastOpModel, CastInt4ToFloatLarge) {
4749 int num_elements = 40 ;
48- std::random_device random_device ;
49- auto rng = std::mt19937 ( random_device ());
50- std::uniform_int_distribution <int8_t > i8dist ( -8 , 7 );
51- auto i8rng = [&] { return i8dist (rng); };
50+ absl::BitGen bitgen ;
51+ auto i8rng = [&] {
52+ return absl::Uniform <int8_t >(absl::IntervalClosed, bitgen, -8 , 7 );
53+ };
5254 std::vector<int8_t > input (num_elements);
5355 std::generate (input.begin (), input.end (), i8rng);
5456 CastOpModel m ({TensorType_INT4, {num_elements}},
@@ -60,6 +62,85 @@ TEST(CastOpModel, CastInt4ToFloatLarge) {
6062 }
6163}
6264
65+ TEST (CastOpModel, CastInt2ToFloat) {
66+ CastOpModel m ({TensorType_INT2, {2 , 4 }}, {TensorType_FLOAT32, {2 , 4 }});
67+ m.Set2BitInput ({1 , 0 , -1 , -2 , 1 , 0 , -1 , -2 });
68+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
69+ EXPECT_THAT (m.ExtractVector <float >(m.output ()),
70+ Pointwise (FloatingPointEq (),
71+ {1 .f , 0 .f , -1 .f , -2 .f , 1 .f , 0 .f , -1 .f , -2 .f }));
72+ }
73+
74+ TEST (CastOpModel, CastInt2ToFloatLarge) {
75+ int num_elements = 40 ;
76+ absl::BitGen bitgen;
77+ auto i2rng = [&] {
78+ return absl::Uniform<int8_t >(absl::IntervalClosed, bitgen, -2 , 1 );
79+ };
80+ std::vector<int8_t > input (num_elements);
81+ std::generate (input.begin (), input.end (), i2rng);
82+ CastOpModel m ({TensorType_INT2, {num_elements}},
83+ {TensorType_FLOAT32, {num_elements}});
84+ m.Set2BitInput (input);
85+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
86+ for (int i = 0 ; i < input.size (); ++i) {
87+ EXPECT_EQ (m.ExtractVector <float >(m.output ())[i], input[i]);
88+ }
89+ }
90+
91+ TEST (CastOpModel, CastFloatToInt4) {
92+ CastOpModel m ({TensorType_FLOAT32, {2 , 4 }}, {TensorType_INT4, {2 , 4 }});
93+ m.PopulateTensor <float >(m.input (), {1 .f , 2 .f , 3 .f , 4 .f , 5 .f , 6 .f , 7 .f , -8 .f });
94+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
95+ TfLiteTensor* output = m.GetOutputTensor (0 );
96+ int num_elements = NumElements (output);
97+ std::vector<int8_t > unpacked_output (num_elements);
98+ tensor_utils::UnpackPackedIntToInt8 (
99+ reinterpret_cast <int8_t *>(output->data .data ), num_elements,
100+ /* bit_width=*/ 4 , unpacked_output.data ());
101+ EXPECT_THAT (unpacked_output, ElementsAreArray ({1 , 2 , 3 , 4 , 5 , 6 , 7 , -8 }));
102+ }
103+
104+ TEST (CastOpModel, CastFloatToInt4Clamp) {
105+ CastOpModel m ({TensorType_FLOAT32, {1 , 4 }}, {TensorType_INT4, {1 , 4 }});
106+ m.PopulateTensor <float >(m.input (), {100 .f , -100 .f , 7 .9f , -8 .9f });
107+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
108+ TfLiteTensor* output = m.GetOutputTensor (0 );
109+ int num_elements = NumElements (output);
110+ std::vector<int8_t > unpacked_output (num_elements);
111+ tensor_utils::UnpackPackedIntToInt8 (
112+ reinterpret_cast <int8_t *>(output->data .data ), num_elements,
113+ /* bit_width=*/ 4 , unpacked_output.data ());
114+ EXPECT_THAT (unpacked_output, ElementsAreArray ({7 , -8 , 7 , -8 }));
115+ }
116+
117+ TEST (CastOpModel, CastFloatToInt2) {
118+ CastOpModel m ({TensorType_FLOAT32, {2 , 4 }}, {TensorType_INT2, {2 , 4 }});
119+ m.PopulateTensor <float >(m.input (),
120+ {1 .f , 0 .f , -1 .f , -2 .f , 1 .f , 0 .f , -1 .f , -2 .f });
121+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
122+ TfLiteTensor* output = m.GetOutputTensor (0 );
123+ int num_elements = NumElements (output);
124+ std::vector<int8_t > unpacked_output (num_elements);
125+ tensor_utils::UnpackPackedIntToInt8 (
126+ reinterpret_cast <int8_t *>(output->data .data ), num_elements,
127+ /* bit_width=*/ 2 , unpacked_output.data ());
128+ EXPECT_THAT (unpacked_output, ElementsAreArray ({1 , 0 , -1 , -2 , 1 , 0 , -1 , -2 }));
129+ }
130+
131+ TEST (CastOpModel, CastFloatToInt2Clamp) {
132+ CastOpModel m ({TensorType_FLOAT32, {1 , 4 }}, {TensorType_INT2, {1 , 4 }});
133+ m.PopulateTensor <float >(m.input (), {100 .f , -100 .f , 1 .9f , -2 .9f });
134+ ASSERT_EQ (m.Invoke (), kTfLiteOk );
135+ TfLiteTensor* output = m.GetOutputTensor (0 );
136+ int num_elements = NumElements (output);
137+ std::vector<int8_t > unpacked_output (num_elements);
138+ tensor_utils::UnpackPackedIntToInt8 (
139+ reinterpret_cast <int8_t *>(output->data .data ), num_elements,
140+ /* bit_width=*/ 2 , unpacked_output.data ());
141+ EXPECT_THAT (unpacked_output, ElementsAreArray ({1 , -2 , 1 , -2 }));
142+ }
143+
63144TEST (CastOpModel, CastFloatToUint8Infinity) {
64145 CastOpModel m ({TensorType_FLOAT32, {2 }}, {TensorType_UINT8, {2 }});
65146 m.PopulateTensor <float >(m.input (), {std::numeric_limits<float >::infinity (),
0 commit comments