Skip to content

Commit e9a05d0

Browse files
Fix issues with tflite on tf2.6 (#1681)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c02f793 commit e9a05d0

File tree

137 files changed

+2777
-643
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

137 files changed

+2777
-643
lines changed

ci_build/azure_pipelines/unit_test.yml

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,9 @@ stages:
55
jobs:
66
- template: 'templates/job_generator.yml'
77
parameters:
8-
# tf 2.5
9-
python_versions: ['3.8']
10-
tf_versions: ['2.6.0rc2']
11-
onnx_opsets: ['14']
12-
job:
13-
steps:
14-
- template: 'unit_test.yml'
15-
report_coverage: 'True'
16-
17-
- template: 'templates/job_generator.yml'
18-
parameters:
19-
# TFJS tf 2.5
8+
# TFJS tf 2.6
209
python_versions: ['3.9']
21-
tf_versions: ['2.5.0']
10+
tf_versions: ['2.6.0']
2211
onnx_opsets: ['']
2312
skip_tfjs_tests: 'False'
2413
skip_tf_tests: 'True'
@@ -29,9 +18,9 @@ stages:
2918

3019
- template: 'templates/job_generator.yml'
3120
parameters:
32-
# TFLite tf 2.5
21+
# TFLite tf 2.6
3322
python_versions: ['3.8']
34-
tf_versions: ['2.5.0']
23+
tf_versions: ['2.6.0']
3524
onnx_opsets: ['']
3625
skip_tflite_tests: 'False'
3726
skip_tf_tests: 'True'
@@ -42,19 +31,20 @@ stages:
4231

4332
- template: 'templates/job_generator.yml'
4433
parameters:
45-
# tf 2.5
34+
# tf 2.6
4635
python_versions: ['3.8']
47-
tf_versions: ['2.5.0']
36+
tf_versions: ['2.6.0']
4837
onnx_opsets: ['']
4938
job:
5039
steps:
5140
- template: 'unit_test.yml'
5241
report_coverage: 'True'
5342

5443
- template: 'templates/job_generator.yml'
44+
# tf 2.5, tf 1.15
5545
parameters:
5646
python_versions: ['3.7']
57-
tf_versions: ['1.15.2','2.3.0']
47+
tf_versions: ['1.15.2','2.5.0']
5848
onnx_opsets: ['']
5949
job:
6050
steps:

tests/test_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2879,6 +2879,7 @@ def validate_graph(g):
28792879

28802880
@check_tf_min_version("1.15")
28812881
@check_opset_min_version(11, "ScatterND")
2882+
@skip_tflite("TFLite uses a pattern for ScatterND so number of DequantizeLinear won't match")
28822883
def test_qdq_optimizer_scatter(self):
28832884
x_val = np.array([10, 20, 30, 40], dtype=np.float32).reshape((4))
28842885
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
@@ -4485,6 +4486,7 @@ def func(x, y, z):
44854486

44864487
@check_tf_min_version("1.15", "tensor_scatter_nd_update for strings needs tf 1.15")
44874488
@check_opset_min_version(11, "ScatterND")
4489+
@skip_tflite("Conversion crashes")
44884490
def test_tensor_scatter_update_str(self):
44894491
x_val = np.array(['A', '♠♣♥♦', 'B', 'C'], dtype=np.str).reshape((4))
44904492
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
@@ -4497,6 +4499,7 @@ def func(x, y, z):
44974499

44984500
@check_tf_min_version("1.15", "tensor_scatter_nd_update for strings needs tf 1.15")
44994501
@check_opset_min_version(11, "ScatterND")
4502+
@skip_tflite("Conversion crashes")
45004503
def test_tensor_scatter_update_str_const(self):
45014504
x_val = np.array(['A', '♠♣♥♦', 'B', 'C'], dtype=np.str).reshape((4))
45024505
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,10 +645,12 @@ class ScatterND:
645645
@classmethod
646646
def version_11(cls, ctx, node, **kwargs):
647647
onnxdtype = ctx.get_dtype(node.input[1])
648-
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2])
648+
zero_tensor = helper.make_tensor("value", onnxdtype, dims=[1], vals=[0])
649+
const_of_shape = ctx.make_node("ConstantOfShape", [node.input[2]], attr={'value': zero_tensor},
650+
shapes=node.output_shapes, dtypes=[onnxdtype])
651+
ctx.replace_input(node, node.input[2], const_of_shape.output[0], 2)
649652
ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64)
650653
ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64)
651-
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype)
652654
# reorder inputs to match onnx
653655
ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]])
654656

tf2onnx/tflite/AbsOptions.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,29 @@ class AbsOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsAbsOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = AbsOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsAbsOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def AbsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

2529
# AbsOptions
2630
def Init(self, buf, pos):
2731
self._tab = flatbuffers.table.Table(buf, pos)
2832

29-
def AbsOptionsStart(builder): builder.StartObject(0)
30-
def AbsOptionsEnd(builder): return builder.EndObject()
33+
def Start(builder): builder.StartObject(0)
34+
def AbsOptionsStart(builder):
35+
"""This method is deprecated. Please switch to Start."""
36+
return Start(builder)
37+
def End(builder): return builder.EndObject()
38+
def AbsOptionsEnd(builder):
39+
"""This method is deprecated. Please switch to End."""
40+
return End(builder)

tf2onnx/tflite/AddNOptions.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,29 @@ class AddNOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsAddNOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = AddNOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsAddNOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def AddNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

2529
# AddNOptions
2630
def Init(self, buf, pos):
2731
self._tab = flatbuffers.table.Table(buf, pos)
2832

29-
def AddNOptionsStart(builder): builder.StartObject(0)
30-
def AddNOptionsEnd(builder): return builder.EndObject()
33+
def Start(builder): builder.StartObject(0)
34+
def AddNOptionsStart(builder):
35+
"""This method is deprecated. Please switch to Start."""
36+
return Start(builder)
37+
def End(builder): return builder.EndObject()
38+
def AddNOptionsEnd(builder):
39+
"""This method is deprecated. Please switch to End."""
40+
return End(builder)

tf2onnx/tflite/AddOptions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ class AddOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsAddOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = AddOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsAddOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def AddOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

@@ -40,7 +44,19 @@ def PotScaleInt16(self):
4044
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
4145
return True
4246

43-
def AddOptionsStart(builder): builder.StartObject(2)
44-
def AddOptionsAddFusedActivationFunction(builder, fusedActivationFunction): builder.PrependInt8Slot(0, fusedActivationFunction, 0)
45-
def AddOptionsAddPotScaleInt16(builder, potScaleInt16): builder.PrependBoolSlot(1, potScaleInt16, 1)
46-
def AddOptionsEnd(builder): return builder.EndObject()
47+
def Start(builder): builder.StartObject(2)
48+
def AddOptionsStart(builder):
49+
"""This method is deprecated. Please switch to Start."""
50+
return Start(builder)
51+
def AddFusedActivationFunction(builder, fusedActivationFunction): builder.PrependInt8Slot(0, fusedActivationFunction, 0)
52+
def AddOptionsAddFusedActivationFunction(builder, fusedActivationFunction):
53+
"""This method is deprecated. Please switch to AddFusedActivationFunction."""
54+
return AddFusedActivationFunction(builder, fusedActivationFunction)
55+
def AddPotScaleInt16(builder, potScaleInt16): builder.PrependBoolSlot(1, potScaleInt16, 1)
56+
def AddOptionsAddPotScaleInt16(builder, potScaleInt16):
57+
"""This method is deprecated. Please switch to AddPotScaleInt16."""
58+
return AddPotScaleInt16(builder, potScaleInt16)
59+
def End(builder): return builder.EndObject()
60+
def AddOptionsEnd(builder):
61+
"""This method is deprecated. Please switch to End."""
62+
return End(builder)

tf2onnx/tflite/ArgMaxOptions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ class ArgMaxOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsArgMaxOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = ArgMaxOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsArgMaxOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def ArgMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

@@ -33,6 +37,15 @@ def OutputType(self):
3337
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
3438
return 0
3539

36-
def ArgMaxOptionsStart(builder): builder.StartObject(1)
37-
def ArgMaxOptionsAddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0)
38-
def ArgMaxOptionsEnd(builder): return builder.EndObject()
40+
def Start(builder): builder.StartObject(1)
41+
def ArgMaxOptionsStart(builder):
42+
"""This method is deprecated. Please switch to Start."""
43+
return Start(builder)
44+
def AddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0)
45+
def ArgMaxOptionsAddOutputType(builder, outputType):
46+
"""This method is deprecated. Please switch to AddOutputType."""
47+
return AddOutputType(builder, outputType)
48+
def End(builder): return builder.EndObject()
49+
def ArgMaxOptionsEnd(builder):
50+
"""This method is deprecated. Please switch to End."""
51+
return End(builder)

tf2onnx/tflite/ArgMinOptions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ class ArgMinOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsArgMinOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = ArgMinOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsArgMinOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def ArgMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

@@ -33,6 +37,15 @@ def OutputType(self):
3337
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
3438
return 0
3539

36-
def ArgMinOptionsStart(builder): builder.StartObject(1)
37-
def ArgMinOptionsAddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0)
38-
def ArgMinOptionsEnd(builder): return builder.EndObject()
40+
def Start(builder): builder.StartObject(1)
41+
def ArgMinOptionsStart(builder):
42+
"""This method is deprecated. Please switch to Start."""
43+
return Start(builder)
44+
def AddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0)
45+
def ArgMinOptionsAddOutputType(builder, outputType):
46+
"""This method is deprecated. Please switch to AddOutputType."""
47+
return AddOutputType(builder, outputType)
48+
def End(builder): return builder.EndObject()
49+
def ArgMinOptionsEnd(builder):
50+
"""This method is deprecated. Please switch to End."""
51+
return End(builder)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# automatically generated by the FlatBuffers compiler, do not modify
4+
5+
# namespace: tflite
6+
7+
import flatbuffers
8+
from flatbuffers.compat import import_numpy
9+
np = import_numpy()
10+
11+
class AssignVariableOptions(object):
12+
__slots__ = ['_tab']
13+
14+
@classmethod
15+
def GetRootAs(cls, buf, offset=0):
16+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
17+
x = AssignVariableOptions()
18+
x.Init(buf, n + offset)
19+
return x
20+
21+
@classmethod
22+
def GetRootAsAssignVariableOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
26+
def AssignVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
27+
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
28+
29+
# AssignVariableOptions
30+
def Init(self, buf, pos):
31+
self._tab = flatbuffers.table.Table(buf, pos)
32+
33+
def Start(builder): builder.StartObject(0)
34+
def AssignVariableOptionsStart(builder):
35+
"""This method is deprecated. Please switch to Start."""
36+
return Start(builder)
37+
def End(builder): return builder.EndObject()
38+
def AssignVariableOptionsEnd(builder):
39+
"""This method is deprecated. Please switch to End."""
40+
return End(builder)

tf2onnx/tflite/BatchMatMulOptions.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ class BatchMatMulOptions(object):
1212
__slots__ = ['_tab']
1313

1414
@classmethod
15-
def GetRootAsBatchMatMulOptions(cls, buf, offset):
15+
def GetRootAs(cls, buf, offset=0):
1616
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
1717
x = BatchMatMulOptions()
1818
x.Init(buf, n + offset)
1919
return x
2020

2121
@classmethod
22+
def GetRootAsBatchMatMulOptions(cls, buf, offset=0):
23+
"""This method is deprecated. Please switch to GetRootAs."""
24+
return cls.GetRootAs(buf, offset)
25+
@classmethod
2226
def BatchMatMulOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
2327
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)
2428

@@ -47,8 +51,23 @@ def AsymmetricQuantizeInputs(self):
4751
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
4852
return False
4953

50-
def BatchMatMulOptionsStart(builder): builder.StartObject(3)
51-
def BatchMatMulOptionsAddAdjX(builder, adjX): builder.PrependBoolSlot(0, adjX, 0)
52-
def BatchMatMulOptionsAddAdjY(builder, adjY): builder.PrependBoolSlot(1, adjY, 0)
53-
def BatchMatMulOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): builder.PrependBoolSlot(2, asymmetricQuantizeInputs, 0)
54-
def BatchMatMulOptionsEnd(builder): return builder.EndObject()
54+
def Start(builder): builder.StartObject(3)
55+
def BatchMatMulOptionsStart(builder):
56+
"""This method is deprecated. Please switch to Start."""
57+
return Start(builder)
58+
def AddAdjX(builder, adjX): builder.PrependBoolSlot(0, adjX, 0)
59+
def BatchMatMulOptionsAddAdjX(builder, adjX):
60+
"""This method is deprecated. Please switch to AddAdjX."""
61+
return AddAdjX(builder, adjX)
62+
def AddAdjY(builder, adjY): builder.PrependBoolSlot(1, adjY, 0)
63+
def BatchMatMulOptionsAddAdjY(builder, adjY):
64+
"""This method is deprecated. Please switch to AddAdjY."""
65+
return AddAdjY(builder, adjY)
66+
def AddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): builder.PrependBoolSlot(2, asymmetricQuantizeInputs, 0)
67+
def BatchMatMulOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs):
68+
"""This method is deprecated. Please switch to AddAsymmetricQuantizeInputs."""
69+
return AddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs)
70+
def End(builder): return builder.EndObject()
71+
def BatchMatMulOptionsEnd(builder):
72+
"""This method is deprecated. Please switch to End."""
73+
return End(builder)

0 commit comments

Comments
 (0)