Skip to content

Commit 66801b9

Browse files
committed
enhance tf.where support
tf.where can accept only one input and it will return indices of elem with true value.
1 parent 05b8e0c commit 66801b9

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,16 @@ def test_where_scalar(self):
15021502
_ = tf.identity(picks, name=_TFOUTPUT)
15031503
self._run_test_case([_OUTPUT], {_INPUT: x_val})
15041504

1505+
@check_opset_min_version(9, "where")
1506+
def test_where_with_cond_only(self):
1507+
for np_type, tf_type in [(np.int32, tf.int32), (np.float32, tf.float32)]:
1508+
x_val = np.random.randint(0, 2, size=[10, 20, 30]).astype(np_type)
1509+
x = tf.placeholder(tf_type, shape=[None] * x_val.ndim, name=_TFINPUT)
1510+
res = tf.where(x)
1511+
_ = tf.identity(res, name=_TFOUTPUT)
1512+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1513+
tf.reset_default_graph()
1514+
15051515
@check_opset_min_version(6, "cast")
15061516
def test_shape_int32(self):
15071517
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=np.float32)

tf2onnx/tfonnx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,15 @@ def logical_compare_op(ctx, node, name, args):
17131713
return nodes
17141714

17151715

1716+
def where_op(ctx, node, name, args):
1717+
# T_y output = Where(T_x condition), return indices of elements whose value are True
1718+
node.type = "NonZero"
1719+
transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name=utils.make_name("where_op_added"))
1720+
ctx.set_shape(transpose_node.output[0], ctx.get_shape(node.output[0]))
1721+
ctx.set_dtype(transpose_node.output[0], ctx.get_dtype(node.output[0]))
1722+
return [node, transpose_node]
1723+
1724+
17161725
# map tensorflow ops to onnx ops. The format below is
17171726
# "TFOP": func_to_map, ["OnnxOp", ...]
17181727
#
@@ -1884,6 +1893,7 @@ def logical_compare_op(ctx, node, name, args):
18841893
"Less": (logical_compare_op, []),
18851894
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18861895
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
1896+
"Where": (where_op, []),
18871897
}
18881898

18891899
_OPSETS = [

0 commit comments

Comments
 (0)