Skip to content

Commit 8dbde42

Browse files
Minor bugfixes to support tflite (#1275)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c4a2142 commit 8dbde42

File tree

5 files changed

+35
-7
lines changed

5 files changed

+35
-7
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def version_1(cls, ctx, node, **kwargs):
268268
branches[branch] = g
269269

270270
_ = ctx.make_node("If", node.input[:1], name=node.name, output_count=len(output_shapes),
271-
shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True, branches=branches)
271+
shapes=output_shapes, dtypes=output_dtypes, outputs=node.output, skip_conversion=True,
272+
branches=branches)
272273

273274

274275
@tf_op(["TensorListSetItem"])
@@ -629,6 +630,7 @@ def parameter_binding(g, inputs, state_vars=None):
629630
else:
630631
binding[k] = inputs[i]
631632
i += 1
633+
utils.make_sure(i == len(inputs), "Parameter count mismatch while binding controlflow")
632634
return binding
633635

634636

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,10 @@ def version_1(cls, ctx, node, **kwargs):
577577
if len(kernel_shape) != 4:
578578
raise ValueError("only Conv2D is supported")
579579
k_h, k_w, k_input_channels, k_channel_multiplier = kernel_shape
580+
if "depth_multiplier" in node.attr:
581+
depth_multiplier = node.get_attr_int("depth_multiplier")
582+
k_input_channels //= depth_multiplier
583+
k_channel_multiplier *= depth_multiplier
580584
if k_input_channels < 1:
581585
raise ValueError("input channel must be positive")
582586
k_output_channels = k_input_channels * k_channel_multiplier

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,8 +1148,10 @@ def any_version(cls, opset, ctx, node, **kwargs):
11481148
# insert Unsqueeze on each input
11491149
for i, n in enumerate(node.inputs):
11501150
dtype = ctx.get_dtype(node.input[i])
1151-
shape = ctx.get_shape(node.input[i]).copy()
1152-
shape.insert(axis, 1)
1151+
shape = ctx.get_shape(node.input[i])
1152+
if shape is not None:
1153+
shape = shape.copy()
1154+
shape.insert(axis, 1)
11531155
new_node = gb.make_unsqueeze(
11541156
{'data': node.input[i], 'axes': [axis]},
11551157
op_name_scope=node.name, shapes=[shape], dtypes=[dtype], return_node=True)

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,29 @@ def rewrite_eye(g, ops):
110110
OpTypePattern("Const", name="fill_value"),
111111
]), "*"
112112
])
113+
pattern7 = \
114+
OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
115+
OpTypePattern("Fill", inputs=[
116+
OpTypePattern("Reshape", inputs=[
117+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
118+
"*",
119+
]),
120+
OpTypePattern("Const", name="fill_value"),
121+
])
122+
])
123+
pattern8 = \
124+
OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
125+
OpTypePattern("Fill"),
126+
OpTypePattern("Fill", inputs=[
127+
OpTypePattern("Reshape", inputs=[
128+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
129+
"*",
130+
]),
131+
OpTypePattern("Const", name="fill_value"),
132+
])
133+
])
113134

114-
for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6]:
135+
for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6, pattern7, pattern8]:
115136
matcher = GraphMatcher(pattern, allow_reorder=True)
116137
match_results = list(matcher.match_ops(ops))
117138
for match_result in match_results:
@@ -146,6 +167,6 @@ def rewrite_eye(g, ops):
146167
zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output)
147168

148169
g.make_node("EyeLike", zero_matrix.output, attr={"dtype": output_dtypes[0]},
149-
name=old_output.name, shapes=output_shapes, dtypes=output_dtypes)
170+
name=old_output.name, shapes=output_shapes, dtypes=output_dtypes, outputs=old_output.output)
150171

151172
return g.get_nodes()

tf2onnx/tflite_handlers/tfl_nn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class TflDepthwiseConv2D:
7070
@classmethod
7171
def to_tf(cls, ctx, node, **kwargs):
7272
separate_fused_activation_function(ctx, node)
73-
# No need to change 'padding' attribute
73+
# No need to change 'padding' or 'depth_multiplier' attributes
7474
stride_h = node.get_attr_int("stride_h")
7575
stride_w = node.get_attr_int("stride_w")
7676
dilation_w_factor = node.get_attr_int("dilation_w_factor")
@@ -81,7 +81,6 @@ def to_tf(cls, ctx, node, **kwargs):
8181
del node.attr["stride_w"]
8282
del node.attr["dilation_h_factor"]
8383
del node.attr["dilation_w_factor"]
84-
del node.attr["depth_multiplier"] # TODO: use this?
8584
transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1], name=None, perm=[1, 2, 3, 0])
8685
transpose_node.skip_conversion = True
8786
node.set_attr("data_format", "NHWC")

0 commit comments

Comments
 (0)