Skip to content

Commit 6ae0320

Browse files
authored
Merge pull request #475 from zhijxu-MS/topk
support topk-10
2 parents 1f85897 + 9cabb52 commit 6ae0320

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tests/test_backend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1333,13 +1333,23 @@ def test_cancel_transpose(self):
13331333
self._run_test_case([_OUTPUT], {_INPUT: x_val})
13341334

13351335
@check_opset_min_version(6, "cast")
1336-
def test_topk(self):
1336+
def test_topk1(self):
13371337
x_val = np.arange(3 * 2 * 3).astype("float32")
13381338
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
13391339
values, _ = tf.nn.top_k(x, 5, sorted=True)
13401340
_ = tf.identity(values, name=_TFOUTPUT)
13411341
self._run_test_case([_OUTPUT], {_INPUT: x_val})
13421342

1343+
@check_opset_min_version(10, "TopK with dynamic K")
1344+
def test_topk2(self):
1345+
x_val = np.arange(3 * 2 * 3).astype("float32")
1346+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1347+
k_val = np.array(10).astype(np.int32)
1348+
k = tf.placeholder(tf.int32, name=_TFINPUT1)
1349+
values, _ = tf.nn.top_k(x, k, sorted=True)
1350+
_ = tf.identity(values, name=_TFOUTPUT)
1351+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: k_val})
1352+
13431353
def test_stack_axis(self):
13441354
for axis in [0, 1]:
13451355
tf.reset_default_graph()

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def version_7(cls, ctx, node, **kwargs):
611611
pass
612612

613613

614-
@tf_op("TopKV2")
614+
@tf_op("TopKV2", onnx_op="TopK")
615615
class TopKV2:
616616
@classmethod
617617
def version_4(cls, ctx, node, **kwargs):
@@ -636,6 +636,15 @@ def version_4(cls, ctx, node, **kwargs):
636636
name=new_cast_name, attr={"to": onnx_pb.TensorProto.INT32},
637637
shapes=[shapes[1]], dtypes=[onnx_pb.TensorProto.INT32])
638638

639+
@classmethod
640+
def version_10(cls, ctx, node, **kwargs):
641+
# onnx only supports input K as a 1D tesor with dtype int64
642+
# while in tf, K is a 0D tensor with dtype int32
643+
k_0d = node.input[1]
644+
cast = ctx.make_node("Cast", [k_0d], attr={"to": onnx_pb.TensorProto.INT64})
645+
k_1d = ctx.make_node("Unsqueeze", cast.output, attr={"axes": [0]})
646+
ctx.replace_input(node, k_0d, k_1d.output[0])
647+
639648

640649
@tf_op("Tile")
641650
class Tile:

0 commit comments

Comments
 (0)