Skip to content

Commit c2c9164

Browse files
fix bugs: handle negative axis in Unpack (Squeeze does not accept negative axes); stop removing matched nodes prematurely in the leakyrelu and thresholded_relu rewriters
1 parent 244e90a commit c2c9164

File tree

3 files changed: +5 additions, -6 deletions

3 files changed: +5 additions, -6 deletions

tf2onnx/onnx_opset/tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,12 @@ class Unpack:
941941
@classmethod
942942
def version_1(cls, ctx, node, **kwargs):
943943
# hack to make up for the missing onnx unpack op
944+
# squeeze does not support negative axis
944945
axis = node.get_attr("axis").i
946+
if axis < 0:
947+
shape = ctx.get_shape(node.input[0])
948+
utils.make_sure(shape is not None, "shape of unpack input is None: {}".format(node.input[0]))
949+
axis += len(shape)
945950
# split the tensor into n outputs
946951
node.type = "Split"
947952
# for each output we need to squeeze axis

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def rewrite_leakyrelu(g, ops):
4040
continue
4141
leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
4242
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
43-
ops.remove(max_node)
44-
ops.remove(mul_node)
4543
ops.append(leakyrelu)
4644
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
4745

tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def rewrite_thresholded_relu(g, ops):
3434
greater_input_node = match.get_op('greater_input')
3535
mul_node = match.get_op("mul")
3636
mul_input_node = match.get_op('mul_input')
37-
cast_node = match.get_op('cast')
3837

3938
greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
4039
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
@@ -43,9 +42,6 @@ def rewrite_thresholded_relu(g, ops):
4342
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
4443
shapes=[g.get_shape(mul_node.output[0])],
4544
dtypes=[g.get_dtype(mul_node.output[0])])
46-
ops.remove(greater_node)
47-
ops.remove(cast_node)
48-
ops.remove(mul_node)
4945
ops.append(thresholded_relu)
5046
g.replace_all_inputs(ops, mul_node.output[0], thresholded_relu.output[0])
5147
return ops

0 commit comments

Comments (0)