Skip to content

Commit 8697dbf

Browse files
majiddadashi authored and copybara-github committed
Add support for int2/int4 in tfl.cast
LiteRT-Converter-PiperOrigin-RevId: 820509011
1 parent 75e7f13 commit 8697dbf

File tree

5 files changed: +81 -11 lines changed

tflite/converter/ir/tfl_ops.td

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
112112
Variadic<TensorOf<allowedOpTypes>>,
113113
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
114114

115+
def TFL_I2 : I<2>;
115116
def TFL_I4 : I<4>;
116117
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
117118

@@ -4072,13 +4073,10 @@ def TFL_CastOp : TFL_Op<"cast", [
40724073
}];
40734074

40744075
let arguments = (ins
4075-
TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
4076+
TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
40764077
);
40774078

4078-
// TODO(b/393644251): Temporary support for INT4 TFL_CastOp. Runtime
4079-
// probably already supports INT4. We should remove the INT4 support here or
4080-
// make sure the runtime supports is there, as part of closing the bug.
4081-
let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
4079+
let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
40824080

40834081
// TFLite's cast op does not utilize CastOptions, instead derives types
40844082
// from the TfLiteTensors.

tflite/converter/tools/versioning/op_version.cc

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1073,8 +1073,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
10731073
}
10741074
return 2;
10751075
case BuiltinOperator_CAST:
1076-
if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
1077-
op_sig.outputs.at(0).type == kTfLiteBFloat16) {
1076+
if (op_sig.inputs.at(0).type == kTfLiteInt2 ||
1077+
op_sig.outputs.at(0).type == kTfLiteInt2) {
1078+
return 8;
1079+
} else if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
1080+
op_sig.outputs.at(0).type == kTfLiteBFloat16) {
10781081
return 7;
10791082
} else if (op_sig.inputs.at(0).type == kTfLiteInt4 &&
10801083
op_sig.outputs.at(0).type == kTfLiteFloat32) {

tflite/converter/tools/versioning/op_version_test.cc

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1467,4 +1467,72 @@ TEST(OpVersionTest, VersioningSqrtTest) {
14671467
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
14681468
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
14691469
}
1470+
1471+
TEST(OpVersionTest, VersioningCastTest) {
1472+
OpSignature fake_op_sig = {};
1473+
fake_op_sig.op = BuiltinOperator_CAST;
1474+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
1475+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1476+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
1477+
1478+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1479+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
1480+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
1481+
1482+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
1483+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1484+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
1485+
1486+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1487+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
1488+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
1489+
1490+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt4);
1491+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
1492+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
1493+
1494+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
1495+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1496+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
1497+
1498+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1499+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
1500+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
1501+
1502+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
1503+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1504+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
1505+
1506+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1507+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
1508+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
1509+
1510+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
1511+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1512+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
1513+
1514+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1515+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
1516+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
1517+
1518+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
1519+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1520+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
1521+
1522+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1523+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
1524+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
1525+
1526+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
1527+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1528+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
1529+
1530+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1531+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
1532+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
1533+
1534+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1535+
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
1536+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
1537+
}
14701538
} // namespace tflite

tflite/converter/tools/versioning/runtime_version.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
112112
{{BuiltinOperator_CAST, 5}, "2.12.0"},
113113
{{BuiltinOperator_CAST, 6}, "2.15.0"},
114114
{{BuiltinOperator_CAST, 7}, "2.17.0"},
115+
{{BuiltinOperator_CAST, 8}, "2.21.0"},
115116
{{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
116117
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
117118
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},

tflite/converter/transforms/tf_legalizations/while_loop_outline_pass.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,10 +59,10 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) {
5959
elemType.isF64())
6060
return true;
6161

62-
// I1, I4, I8, I16, I32, I64 types are allowed.
63-
if (elemType.isInteger(1) || elemType.isInteger(4) || elemType.isInteger(8) ||
64-
elemType.isInteger(16) || elemType.isInteger(32) ||
65-
elemType.isInteger(64))
62+
// I1, I2, I4, I8, I16, I32, I64 types are allowed.
63+
if (elemType.isInteger(1) || elemType.isInteger(2) || elemType.isInteger(4) ||
64+
elemType.isInteger(8) || elemType.isInteger(16) ||
65+
elemType.isInteger(32) || elemType.isInteger(64))
6666
return true;
6767

6868
// Complex<F<32>> is allowed.

0 commit comments

Comments
 (0)