Skip to content

Commit d39427a

Browse files
authored
Merge pull request #300 from pengwa/fix_bing
Fix bugs found in real models.
2 parents 7907fa8 + 5152c2a commit d39427a

File tree

7 files changed

+52
-22
lines changed

7 files changed

+52
-22
lines changed

tests/test_backend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1540,7 +1540,7 @@ def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
15401540

15411541
self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val})
15421542

1543-
@unittest.skip("TODO: add a common utility for onnxruntime version check in another PR")
1543+
@unittest.skipIf(BACKEND in ["onnxruntime"], "onnxruntime Slice did not supported BOOL.")
15441544
def test_matrix_band_part(self):
15451545
input_val = np.random.randint(0, 666, (10, 15)).astype(np.int32)
15461546
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
@@ -1550,6 +1550,16 @@ def test_matrix_band_part(self):
15501550
_ = tf.identity(res1, name=_TFOUTPUT1)
15511551
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
15521552

1553+
@unittest.skipIf(BACKEND in ["onnxruntime"], "onnxruntime Slice did not supported BOOL.")
1554+
def test_matrix_band_part_2(self):
1555+
input_val = np.random.randint(0, 666, (1, 1)).astype(np.int32)
1556+
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
1557+
res = tf.matrix_band_part(input_x, -1, 0)
1558+
res1 = tf.matrix_band_part(input_x, 0, -1)
1559+
_ = tf.identity(res, name=_TFOUTPUT)
1560+
_ = tf.identity(res1, name=_TFOUTPUT1)
1561+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
1562+
15531563
def test_floordiv(self):
15541564
input_val_1 = np.random.random_sample(100).astype(np.int32)
15551565
input_val_2 = (np.random.random_sample(100) + 1).astype(np.int32)

tf2onnx/function/matrixbandpart.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,18 @@ def matrixbandpart_op(ctx, node, name, args):
4141
g = ctx.create_new_graph_with_same_config()
4242
node_name = utils.make_name("const_zero_bool")
4343
const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
44-
slice_node = g.make_node(op_type="Slice", inputs=["line"],
44+
ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
45+
46+
# shift right the line and add zero at the left.
47+
new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"], attr={"axis": counter_axis},
48+
dtypes=[onnx_pb.TensorProto.BOOL])
49+
slice_node = g.make_node(op_type="Slice", inputs=[new_line.output[0]],
4550
attr={"axes": [counter_axis], "starts": [0], "ends": [-1]})
46-
new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], slice_node.output[0]],
47-
outputs=["line_out"], attr={"axis": counter_axis})
51+
4852
body_nodes = [slice_node, new_line,
4953
g.make_node("Identity", ["cond"], outputs=["cond_out"]),
50-
g.make_node("Identity", ["line"], outputs=["res"])]
54+
g.make_node("Identity", ["line"], outputs=["res"]),
55+
g.make_node("Identity", [slice_node.output[0]], outputs=["line_out"])]
5156
g.set_nodes(body_nodes)
5257
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
5358
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])

tf2onnx/function/select.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_
204204

205205
g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
206206
g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())
207-
g.add_graph_output(loop_output_id, output_data_type, output_shape[1:])
207+
208+
# use None for all dims, just keep original rank. Because it is observed, dims might be changed in loop.
209+
g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))
208210

209211
return g
210212

@@ -234,5 +236,5 @@ def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cu
234236
)
235237
nodes.append(identity_node)
236238
g.set_nodes(nodes)
237-
g.add_graph_output("y", data_type, output_shape)
239+
g.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
238240
return g

tf2onnx/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,9 @@ def _extract_sub_graph_nodes(self, dest_node, input_checker=None):
940940
# we don't care about nested graph here, just handle current graph cropping.
941941
node = self.get_node_by_output(input_id, search_in_parent_graphs=False)
942942
if not node:
943-
# some node (for example Scan) has optional inputs, which
944-
# might has empty input.
945-
# subgraph might has input defined in outer graph
943+
# some nodes (for example Scan) have optional inputs, which
944+
# might have empty input.
945+
# subgraph might have input defined in outer graph
946946
continue
947947
if node not in res_set:
948948
if input_checker and input_checker(node) is False:

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def rewrite(self, context):
5656

5757
body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
5858
body_outputs = cond_g_info.outputs + cell_g_info.outputs
59+
for out_tensor_value_info in body_outputs:
60+
out_tensor_value_info.shape = utils.create_vague_shape_like(out_tensor_value_info.shape)
61+
5962
loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)
6063

6164
# create loop body graph inputs
@@ -74,7 +77,7 @@ def rewrite(self, context):
7477
dtype = tensor_value_info.dtype
7578
shape = tensor_value_info.shape
7679

77-
loop_body_g.add_graph_input(input_name, dtype, shape)
80+
loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))
7881

7982
body_nodes_to_append = []
8083
for input_ta in loop_props.tensor_array_inputs:

tf2onnx/tfonnx.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,17 +1300,10 @@ def minmax_op(ctx, node, name, args):
13001300

13011301
def pack_op(ctx, node, name, args):
13021302
# hack to make up for the missing onnx pack op
1303-
1304-
pack_shape = ctx.get_shape(node.output[0])
1305-
if not pack_shape:
1306-
# sometimes Pack output shape is None (for example Pack is following control flow Exit op)
1307-
input_cnt = len(node.inputs)
1308-
input_shape = ctx.get_shape(node.input[0])
1309-
if input_shape:
1310-
pack_shape = [input_cnt] + input_shape
1311-
ctx.set_shape(node.output[0], pack_shape)
1312-
13131303
axis = node.get_attr("axis").i
1304+
if axis < 0:
1305+
axis += len(ctx.get_shape(node.input[0])) + 1
1306+
13141307
nodes = []
13151308
inputs = []
13161309
dtype = None
@@ -1747,6 +1740,14 @@ def zeroslike_op(ctx, node, name, args):
17471740
return mul_op
17481741

17491742

1743+
def softmax_op(ctx, node, name, args):
1744+
# T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
1745+
# T output = Softmax(T input, @int axis). Default axis is 1.
1746+
logits_rank = len(ctx.get_shape(node.input[0]))
1747+
node.set_attr("axis", logits_rank - 1)
1748+
return node
1749+
1750+
17501751
# map tensorflow ops to onnx ops. The format below is
17511752
# "TFOP": func_to_map, ["OnnxOp", ...]
17521753
#
@@ -1829,7 +1830,7 @@ def zeroslike_op(ctx, node, name, args):
18291830
"Sqrt": (direct_op, []),
18301831
"Square": (square_op, []),
18311832
"SquaredDifference": (squareddifference_op, []),
1832-
"Softmax": (direct_op, ["Softmax"]),
1833+
"Softmax": (softmax_op, ["Softmax"]),
18331834
"StopGradient": (identity_op, ["Identity"]),
18341835
"StridedSlice": (stridedslice_op, []),
18351836
"Sub": (broadcast_op, []),
@@ -2261,6 +2262,10 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
22612262

22622263
input_name = op.input[i]
22632264
dtype = g.get_dtype(input_name)
2265+
if dtype is None:
2266+
log.warning("Adding Cast for op %s (type is %s)' input: %s, dtype should not be None",
2267+
op.name, op.type, input_name)
2268+
22642269
if dtype != onnx_pb.TensorProto.FLOAT:
22652270
output_dtype = dtype
22662271
if input_node and input_node.type == "Cast" \

tf2onnx/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,8 @@ def construct_graph_from_nodes(parent_g, nodes, outputs, shapes, dtypes):
335335

336336
def tf_name_scope(name):
337337
return '/'.join(name.split('/')[:-1])
338+
339+
340+
def create_vague_shape_like(shape):
341+
make_sure(len(shape) >= 0, "rank should be >= 0")
342+
return [-1 for i in enumerate(shape)]

0 commit comments

Comments
 (0)