Skip to content

Commit 9cabb52

Browse files
committed
support topk-10
1 parent 9713ffd commit 9cabb52

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tests/test_backend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1333,13 +1333,23 @@ def test_cancel_transpose(self):
13331333
self._run_test_case([_OUTPUT], {_INPUT: x_val})
13341334

13351335
@check_opset_min_version(6, "cast")
1336-
def test_topk(self):
1336+
def test_topk1(self):
13371337
x_val = np.arange(3 * 2 * 3).astype("float32")
13381338
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
13391339
values, _ = tf.nn.top_k(x, 5, sorted=True)
13401340
_ = tf.identity(values, name=_TFOUTPUT)
13411341
self._run_test_case([_OUTPUT], {_INPUT: x_val})
13421342

1343+
@check_opset_min_version(10, "TopK with dynamic K")
1344+
def test_topk2(self):
1345+
x_val = np.arange(3 * 2 * 3).astype("float32")
1346+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1347+
k_val = np.array(10).astype(np.int32)
1348+
k = tf.placeholder(tf.int32, name=_TFINPUT1)
1349+
values, _ = tf.nn.top_k(x, k, sorted=True)
1350+
_ = tf.identity(values, name=_TFOUTPUT)
1351+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: k_val})
1352+
13431353
def test_stack_axis(self):
13441354
for axis in [0, 1]:
13451355
tf.reset_default_graph()

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def version_7(cls, ctx, node, **kwargs):
609609
pass
610610

611611

612-
@tf_op("TopKV2")
612+
@tf_op("TopKV2", onnx_op="TopK")
613613
class TopKV2:
614614
@classmethod
615615
def version_4(cls, ctx, node, **kwargs):
@@ -634,6 +634,15 @@ def version_4(cls, ctx, node, **kwargs):
634634
name=new_cast_name, attr={"to": onnx_pb.TensorProto.INT32},
635635
shapes=[shapes[1]], dtypes=[onnx_pb.TensorProto.INT32])
636636

637+
@classmethod
638+
def version_10(cls, ctx, node, **kwargs):
639+
# onnx only supports input K as a 1D tesor with dtype int64
640+
# while in tf, K is a 0D tensor with dtype int32
641+
k_0d = node.input[1]
642+
cast = ctx.make_node("Cast", [k_0d], attr={"to": onnx_pb.TensorProto.INT64})
643+
k_1d = ctx.make_node("Unsqueeze", cast.output, attr={"axes": [0]})
644+
ctx.replace_input(node, k_0d, k_1d.output[0])
645+
637646

638647
@tf_op("Tile")
639648
class Tile:

0 commit comments

Comments
 (0)