Skip to content

Commit d13d4bf

Browse files
majiddadashicopybara-github
authored andcommitted
Add support for kTfLiteInt2 to Dequantize kernels.
This change enables the Dequantize and PerChannelDequantize operations to handle 2-bit integer inputs (`kTfLiteInt2`). It includes logic to unpack the packed 2-bit integers into int8_t before performing the dequantization and adds new test cases for both per-tensor and per-channel dequantization with kTfLiteInt2. LiteRT-Converter-PiperOrigin-RevId: 822207279
1 parent f3e74e5 commit d13d4bf

File tree

4 files changed

+11
-1
lines changed

4 files changed

+11
-1
lines changed

tflite/converter/ir/tfl_ops.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4279,7 +4279,7 @@ def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> {
42794279
quantization parameters.
42804280
}];
42814281

4282-
let arguments = (ins TFL_TensorOf<[QI4, QI8, QUI8, QI16, F16]>:$input);
4282+
let arguments = (ins TFL_TensorOf<[QI2, QI4, QI8, QUI8, QI16, F16]>:$input);
42834283

42844284
let results = (outs TFL_FpTensor:$output);
42854285

tflite/converter/tools/versioning/op_version.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
499499
return 1;
500500

501501
case BuiltinOperator_DEQUANTIZE:
502+
if (op_sig.inputs.at(0).type == kTfLiteInt2) {
503+
return 7;
504+
}
502505
if (op_sig.inputs.at(0).type == kTfLiteInt4) {
503506
return 6;
504507
}

tflite/converter/tools/versioning/op_version_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,12 @@ TEST(OpVersionTest, VersioningDequantizeTest) {
757757
fake_op_sig.ext_options.dequantize.is_per_channel_quantized = true;
758758
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
759759

760+
fake_op_sig = {
761+
.op = BuiltinOperator_DEQUANTIZE,
762+
.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2),
763+
};
764+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
765+
760766
fake_op_sig = {
761767
.op = BuiltinOperator_DEQUANTIZE,
762768
.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32),

tflite/converter/tools/versioning/runtime_version.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
326326
{{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"},
327327
{{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"},
328328
{{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"},
329+
{{BuiltinOperator_DEQUANTIZE, 7}, "2.21.0"},
329330
{{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"},
330331
{{BuiltinOperator_EQUAL, 1}, "1.14.0"},
331332
{{BuiltinOperator_EQUAL, 2}, "1.14.0"},

0 commit comments

Comments
 (0)