
Commit 9d22804

Merge pull request #333 from zhijxu-MS/push_branch
bug fix fill_op, enhance tf.where, enhance Gather shape inference
2 parents bee1f80 + 97d6f4f commit 9d22804

4 files changed: +46 -13 lines changed
tests/test_backend.py

Lines changed: 10 additions & 0 deletions

@@ -1502,6 +1502,16 @@ def test_where_scalar(self):
         _ = tf.identity(picks, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

+    @check_opset_min_version(9, "where")
+    def test_where_with_cond_only(self):
+        for np_type, tf_type in [(np.int32, tf.int32), (np.float32, tf.float32)]:
+            x_val = np.random.randint(0, 2, size=[10, 20, 30]).astype(np_type)
+            x = tf.placeholder(tf_type, shape=[None] * x_val.ndim, name=_TFINPUT)
+            res = tf.where(x)
+            _ = tf.identity(res, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val})
+            tf.reset_default_graph()
+
     @check_opset_min_version(6, "cast")
     def test_shape_int32(self):
         x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=np.float32)
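
For reference, the single-argument form tf.where(condition) exercised by this new test returns the coordinates of the True/non-zero elements, one row per element, which np.argwhere reproduces for a plain NumPy array. A minimal sketch of the expected result for the test's [10, 20, 30] input (standalone NumPy, not part of the test suite):

import numpy as np

# tf.where(condition) returns int64 coordinates with shape [num_true, rank];
# np.argwhere gives the same layout for a NumPy array.
x_val = np.random.randint(0, 2, size=[10, 20, 30]).astype(np.int32)
indices = np.argwhere(x_val)
print(indices.shape)  # (num_nonzero, 3)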

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 2 deletions

@@ -398,11 +398,11 @@ def _maxmin_handler(self, trans, node):
         for i in all_other_inputs:
             target_node = self._g.get_node_by_output(i)
             numpy_val = target_node.get_tensor_value(as_list=False)
-            rank = np.rank(numpy_val)
+            rank = numpy_val.ndim
             if rank == 4:
                 transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
                 target_node.set_tensor_value(transposed_val)
-            elif rank == 1: # scalar
+            elif rank == 1:  # scalar
                 # do nothing
                 pass
             else:
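
The two-line fix replaces np.rank, which was deprecated in favor of ndarray.ndim and removed from newer NumPy releases, with the ndim attribute. A small standalone sketch of the same rank check (illustrative values, not the optimizer's graph objects):

import numpy as np

# ndarray.ndim reports the number of dimensions; np.rank no longer exists
# in current NumPy, which is what this commit works around.
val_4d = np.zeros((1, 2, 3, 4), dtype=np.float32)
val_1d = np.array([5.0], dtype=np.float32)
assert val_4d.ndim == 4   # NHWC constant: transpose with perm (0, 3, 1, 2)
assert val_1d.ndim == 1   # single-element constant: left as-is
print(np.transpose(val_4d, (0, 3, 1, 2)).shape)  # (1, 4, 2, 3)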

tf2onnx/shape_inference.py

Lines changed: 11 additions & 0 deletions

@@ -141,6 +141,17 @@ def infer_shape_for_node(g, node):
         log.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
         return True

+    if node.type == "Gather":
+        # uses the following link to know how to infer the shape of the output
+        # https://www.tensorflow.org/api_docs/python/tf/gather
+        shape_params = g.get_shape(node.input[0])
+        shape_indices = g.get_shape(node.input[1])
+        axis = node.input[2].get_tensor_value()
+
+        shape = shape_params[:axis] + shape_indices + shape_indices[axis+1:]
+        g.set_shape(node.output[0], shape)
+        return True
+
     if node.type in ["All", "Any", "Min"]:
         axis_node = node.inputs[1]
         axis = axis_node.get_tensor_value()
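
Per the linked tf.gather documentation (with batch_dims left at its default of 0), the output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:], i.e. the trailing dimensions come from params. A standalone sketch of that documented rule over plain Python shape lists (the helper name is illustrative only):

def gather_output_shape(shape_params, shape_indices, axis):
    # Documented tf.gather rule (batch_dims=0):
    # params.shape[:axis] + indices.shape + params.shape[axis + 1:]
    return shape_params[:axis] + shape_indices + shape_params[axis + 1:]

# params [10, 20, 30], indices [5, 7], axis=1 -> [10, 5, 7, 30]
print(gather_output_shape([10, 20, 30], [5, 7], 1))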

tf2onnx/tfonnx.py

Lines changed: 23 additions & 11 deletions

@@ -1463,17 +1463,16 @@ def fill_op(ctx, node, name, args):
     # both shape and value in tensorflow are passed as tensor.
     # In onnx the value is an attribute so we need to fetch the value as const which
     # sooner or later will be a problem for tensorflow-onnx.
-    shape = ctx.get_shape(node.output[0])
-    utils.make_sure(all(i >= 0 for i in shape), "shape attr should not be less than zero")
-    value = node.inputs[1].get_tensor_value()
-    value_proto = numpy_helper.from_array(node.inputs[1].get_tensor_value(as_list=False))
-    dtype = value_proto.data_type
-    # onnx spec says value MUST be float.
-    node.set_attr("value", float(value))
-    node.set_attr("shape", shape)
-    node.set_attr("dtype", dtype)
-    del node.input[:]
-    return node
+    # ConstantOfShape in onnxruntime only supports int64, so insert a cast op
+    input_dtype_is_int64 = utils.ONNX_TO_NUMPY_DTYPE[ctx.get_dtype(node.input[0])] == np.int64
+    if not input_dtype_is_int64:
+        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.INT64)
+    dtype = ctx.get_dtype(node.output[0])
+    value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.ONNX_TO_NUMPY_DTYPE[dtype])
+    value_proto = numpy_helper.from_array(value)
+    node.set_attr("value", value_proto)
+    del node.input[1]
+    return [node] if input_dtype_is_int64 else [node, cast_node]


 def reverse_op8(ctx, node, name, args):
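
For context, the rewritten fill_op maps TensorFlow's Fill(dims, value) onto ONNX ConstantOfShape (opset 9): the shape travels as an int64 input tensor (hence the optional Cast) and the fill value as a one-element tensor in the "value" attribute. A minimal standalone sketch of that node pattern using the onnx helper API (tensor names are made up for illustration):

import numpy as np
from onnx import helper, numpy_helper, TensorProto

# Cast the shape input to int64, then let ConstantOfShape broadcast the
# one-element "value" tensor to that shape.
fill_value = numpy_helper.from_array(np.array([9.0], dtype=np.float32), name="fill_value")
cast = helper.make_node("Cast", ["dims"], ["dims_int64"], to=TensorProto.INT64)
const_of_shape = helper.make_node("ConstantOfShape", ["dims_int64"], ["filled"], value=fill_value)
print(cast)
print(const_of_shape)
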
@@ -1723,6 +1722,18 @@ def logical_compare_op(ctx, node, name, args):
     return nodes


+def where_op(ctx, node, name, args):
+    # T_y output = Where(T_x condition), returns the indices of elements whose value is True
+    node.type = "NonZero"
+    # in onnx, indices are returned in this way: [[ind_a_0, ind_b_0, ...], [ind_a_1, ind_b_1, ...]];
+    # while in tf, the result will be [[ind_a_0, ind_a_1, ...], [ind_b_0, ind_b_1, ...], ...]
+    # this is the reason a transpose node is inserted here.
+    transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name=utils.make_name("where_op_added"))
+    ctx.copy_shape(node.output[0], transpose_node.output[0])
+    ctx.copy_dtype(node.output[0], transpose_node.output[0])
+    return [node, transpose_node]
+
+
 # map tensorflow ops to onnx ops. The format below is
 # "TFOP": func_to_map, ["OnnxOp", ...]
 #
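
The layout difference described in the comments can be checked directly with NumPy: ONNX NonZero follows numpy.nonzero and yields one row per dimension (shape [rank, N]), while tf.where(condition) yields one row per found element (shape [N, rank]), which is why the converter appends a Transpose. A small sketch (NonZero itself requires opset 9, matching the new test's @check_opset_min_version(9, "where")):

import numpy as np

cond = np.array([[True, False],
                 [False, True]])
onnx_style = np.stack(np.nonzero(cond))  # [[0, 1], [0, 1]] -> one row per dimension
tf_style = onnx_style.T                  # [[0, 0], [1, 1]] -> one row per True element
print(onnx_style)
print(tf_style)
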
@@ -1894,6 +1905,7 @@ def logical_compare_op(ctx, node, name, args):
     "Less": (logical_compare_op, []),
     "ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
     "ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
+    "Where": (where_op, []),
 }

 _OPSETS = [
