Skip to content

Commit 1bb6729

Browse files
majiddadashicopybara-github
authored andcommitted
Generalize PackInt8IntoDenseInt4 to support 2-bit and 4-bit packing.
Rename `PackInt8IntoDenseInt4` to `PackInt8IntoDenseInt` and add a `bit_width` parameter. Implement packing logic for both 2-bit and 4-bit integers. Update existing call sites in `quantize.cc` and `transpose.cc` to use the new function signature. Add new unit tests for both 2-bit and 4-bit packing. PiperOrigin-RevId: 818850218
1 parent 583de5f commit 1bb6729

File tree

5 files changed

+87
-27
lines changed

5 files changed

+87
-27
lines changed

tflite/kernels/internal/portable_tensor_utils.cc

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,40 @@ void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
143143
}
144144
}
145145

146-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
147-
int8_t* dst_buffer) {
148-
// num_elements means the number of elements regardless of packed or unpacked.
149-
// For example, 3 elements means both
150-
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
151-
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
152-
// 2) Unpacked: 3 int8's = 3 bytes.
153-
// stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
154-
for (int i = 0; i < num_elements - 1; i += 2) {
155-
dst_buffer[i / 2] = src_buffer[i] & 0x0F;
156-
dst_buffer[i / 2] |= src_buffer[i + 1] << 4;
157-
}
158-
auto packed_size = (num_elements + 1) / 2;
159-
160-
// Copy the final nibble if the buffer is odd-lengthed
161-
if (num_elements % 2 != 0) {
162-
dst_buffer[packed_size - 1] = src_buffer[num_elements - 1] & 0x0F;
146+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
147+
int bit_width, int8_t* dst_buffer) {
148+
assert(bit_width == 2 || bit_width == 4);
149+
if (bit_width == 4) {
150+
// num_elements means the number of elements regardless of packed or
151+
// unpacked. For example, 3 elements means both
152+
// 1) Unpacked: 3 int8's = 3 bytes.
153+
// stored in src_buffer[0], src_buffer[1] and src_buffer[2] (j = 0..2)
154+
// 2) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
155+
// stored in dst_buffer[0] and dst_buffer[1] (i = 0..1)
156+
for (int i = 0; i < num_elements / 2; ++i) {
157+
dst_buffer[i] = (src_buffer[2 * i] & 0x0F) | (src_buffer[2 * i + 1] << 4);
158+
}
159+
// If the buffer size is odd, pack the final nibble.
160+
if (num_elements % 2 != 0) {
161+
dst_buffer[num_elements / 2] = src_buffer[num_elements - 1] & 0x0F;
162+
}
163+
} else if (bit_width == 2) {
164+
for (int i = 0; i < num_elements / 4; ++i) {
165+
dst_buffer[i] = (src_buffer[4 * i] & 0x03) |
166+
((src_buffer[4 * i + 1] & 0x03) << 2) |
167+
((src_buffer[4 * i + 2] & 0x03) << 4) |
168+
((src_buffer[4 * i + 3] & 0x03) << 6);
169+
}
170+
// Handle the remaining elements.
171+
int remaining_elements = num_elements % 4;
172+
if (remaining_elements > 0) {
173+
int8_t packed_val = 0;
174+
for (int i = 0; i < remaining_elements; ++i) {
175+
packed_val |= (src_buffer[num_elements - remaining_elements + i] & 0x03)
176+
<< (i * 2);
177+
}
178+
dst_buffer[num_elements / 4] = packed_val;
179+
}
163180
}
164181
}
165182

tflite/kernels/internal/portable_tensor_utils.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -635,20 +635,24 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
635635
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
636636
int bit_width, int8_t* dst_buffer);
637637

638-
// Pack `src_buffer` into a densely packed buffer of int4 values.
638+
// Pack `src_buffer` into a densely packed buffer of int2 or int4 values.
639639
// Parameters:
640-
// src_buffer : Buffer containing int4 values stored in int8 memory.
640+
// src_buffer : Buffer containing int2 or int4 values stored in int8
641+
// memory.
641642
// num_elements : Number of elements stored in the buffer. Note that this can
642643
// be smaller than the size of `src_buffer` by 1 if it's odd,
643644
// in which case the last nibble in `src_buffer` is ignored.
644645
// This should be equal to the size of `dst_buffer`.
646+
// bit_width : The bit width of the packed elements (either 2 or 4).
645647
// dst_buffer : Buffer to pack into. Should be allocated by the caller.
646648
// Size should be at least `num_elements`.
647649
// Notes:
648-
// For example, given `src_buffer = {0x02, 0x01, 0x04, 0x03}`, calling this
649-
// function will return `dst_buffer = {0x12, 0x34}`.
650-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
651-
int8_t* dst_buffer);
650+
// For 4-bit packing: e.g., given `src_buffer = {0x02, 0x01, 0x04, 0x03}`,
651+
// calling this function will return `dst_buffer = {0x12, 0x34}`.
652+
// For 2-bit packing: e.g., given `src_buffer = {0x00, 0x01, 0x00, 0x02}`,
653+
// calling this function will return `dst_buffer = {0x84}`.
654+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
655+
int bit_width, int8_t* dst_buffer);
652656
} // namespace tensor_utils
653657

654658
} // namespace tflite

tflite/kernels/internal/tensor_utils_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,6 +2149,44 @@ TEST(uKernels, UnpackInt2OddLength) {
21492149
testing::Pointwise(testing::Eq(), expected_output));
21502150
}
21512151

2152+
TEST(uKernels, PackInt4Basic) {
2153+
const int8_t input[4] = {-8, 3, -2, -5};
2154+
const int8_t expected_output[2] = {0x38, static_cast<int8_t>(0xBE)};
2155+
int8_t actual_output[2];
2156+
PackInt8IntoDenseInt(input, 4, 4, actual_output);
2157+
EXPECT_THAT(actual_output,
2158+
testing::Pointwise(testing::Eq(), expected_output));
2159+
}
2160+
2161+
TEST(uKernels, PackInt4OddLength) {
2162+
// `num_elements` is odd, so the last element 0x4 should be ignored
2163+
const int8_t input[3] = {1, 2, 3};
2164+
const int8_t expected_output[2] = {0x21, 0x03};
2165+
int8_t actual_output[2];
2166+
PackInt8IntoDenseInt(input, 3, 4, actual_output);
2167+
EXPECT_THAT(actual_output,
2168+
testing::Pointwise(testing::Eq(), expected_output));
2169+
}
2170+
2171+
TEST(uKernels, PackInt2Basic) {
2172+
const int8_t input[4] = {0, -1, -2, 1};
2173+
const int8_t expected_output[1] = {0x6C};
2174+
int8_t actual_output[1];
2175+
PackInt8IntoDenseInt(input, 4, 2, actual_output);
2176+
EXPECT_THAT(actual_output,
2177+
testing::Pointwise(testing::Eq(), expected_output));
2178+
}
2179+
2180+
TEST(uKernels, PackInt2OddLength) {
2181+
// `num_elements` is odd
2182+
const int8_t input[3] = {0, -2, 1};
2183+
const int8_t expected_output[1] = {0x18};
2184+
int8_t actual_output[1];
2185+
PackInt8IntoDenseInt(input, 3, 2, actual_output);
2186+
EXPECT_THAT(actual_output,
2187+
testing::Pointwise(testing::Eq(), expected_output));
2188+
}
2189+
21522190
} // namespace tensor_utils
21532191
} // namespace tflite
21542192

tflite/kernels/quantize.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "tflite/core/c/common.h"
2222
#include "tflite/kernels/internal/optimized/optimized_ops.h"
23+
#include "tflite/kernels/internal/portable_tensor_utils.h"
2324
#include "tflite/kernels/internal/quantization_util.h"
2425
#include "tflite/kernels/internal/reference/reference_ops.h"
2526
#include "tflite/kernels/internal/reference/requantize.h"
@@ -109,8 +110,8 @@ void AffineQuantizeToInt4(const tflite::QuantizationParams& op_params,
109110
int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
110111
quantized_buffer[i] = clamped;
111112
}
112-
tensor_utils::PackInt8IntoDenseInt4(quantized_buffer.data(), flat_size,
113-
output_data);
113+
tensor_utils::PackInt8IntoDenseInt(quantized_buffer.data(), flat_size,
114+
/*bit_width=*/4, output_data);
114115
}
115116

116117
void ReportError(TfLiteContext* context, TfLiteType input_type,

tflite/kernels/transpose.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
142142
params, GetTensorShape(op_context.input), unpacked_input_data.get(),
143143
GetTensorShape(op_context.output), unpacked_output_data.get());
144144
// Pack the output back to int4.
145-
tflite::tensor_utils::PackInt8IntoDenseInt4(
145+
tflite::tensor_utils::PackInt8IntoDenseInt(
146146
unpacked_output_data.get(),
147147
GetTensorShape(op_context.input).FlatSize(),
148-
GetTensorData<int8_t>(op_context.output));
148+
/*bit_width=*/4, GetTensorData<int8_t>(op_context.output));
149149
break;
150150
}
151151
case kTfLiteInt16:

0 commit comments

Comments
 (0)