Skip to content

Commit 8afe70b

Browse files
majiddadashicopybara-github
authored andcommitted
Add support for kTfLiteInt2 (srq) in tfl.fully_connected.
LiteRT-Converter-PiperOrigin-RevId: 822405584
1 parent 945078f commit 8afe70b

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

tflite/converter/ir/tfl_ops.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
11001100

11011101
let arguments = (ins
11021102
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
1103-
TFL_TensorOf<[F32, QI4, QI8, QUI8, QI16]>:$filter,
1103+
TFL_TensorOf<[F32, QI2, QI4, QI8, QUI8, QI16]>:$filter,
11041104
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
11051105

11061106
TFL_AFAttr:$fused_activation_function,

tflite/converter/tools/versioning/op_version.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
177177
reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
178178
TFLITE_DCHECK(fully_connected_params != nullptr);
179179

180+
if (op_sig.inputs.at(1).type == kTfLiteInt2) {
181+
return 14;
182+
}
183+
180184
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
181185
op_sig.inputs.at(1).type == kTfLiteInt4 &&
182186
op_sig.outputs.at(0).type == kTfLiteInt16) {

tflite/converter/tools/versioning/op_version_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) {
733733
};
734734
fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true;
735735
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12);
736+
737+
fake_op_sig = {
738+
.op = BuiltinOperator_FULLY_CONNECTED,
739+
.inputs = CreateOpSignatureTensorSpecs(
740+
std::vector<TfLiteType>{kTfLiteInt8, kTfLiteInt2}),
741+
.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8),
742+
.builtin_data = reinterpret_cast<void*>(&fully_connected_params),
743+
};
744+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 14);
736745
}
737746

738747
TEST(OpVersionTest, VersioningDequantizeTest) {

tflite/converter/tools/versioning/runtime_version.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
139139
{{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"},
140140
{{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"},
141141
{{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"},
142+
{{BuiltinOperator_FULLY_CONNECTED, 14}, "2.21.0"},
142143
{{BuiltinOperator_GATHER, 1}, "1.6.0"},
143144
{{BuiltinOperator_GATHER, 2}, "1.14.0"},
144145
{{BuiltinOperator_GATHER, 3}, "1.15.0"},

0 commit comments

Comments
 (0)