
Commit 0f1e942

arg min/max cast int32
1 parent bd38f61 commit 0f1e942
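The one-line summary is terse, so a brief note on the motivation: TF's ArgMin/ArgMax can emit int32 indices (via the output_type attribute), but ONNX ArgMin/ArgMax always emit int64, so the converter has to append a Cast to preserve the dtype the TF graph declared. A minimal sketch of the emitted pattern, built with the onnx helper API (tensor and graph names here are illustrative, not from this commit):

    import onnx
    from onnx import TensorProto, helper

    # ONNX ArgMax always yields int64 indices; a trailing Cast restores int32.
    argmax = helper.make_node("ArgMax", ["x"], ["idx_i64"], axis=0, keepdims=0)
    cast = helper.make_node("Cast", ["idx_i64"], ["idx"], to=TensorProto.INT32)

    graph = helper.make_graph(
        [argmax, cast], "argmax_int32_sketch",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2])],
        [helper.make_tensor_value_info("idx", TensorProto.INT32, [2])],
    )
    onnx.checker.check_model(helper.make_model(graph))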

2 files changed (+37, -9 lines)

tests/test_backend.py

Lines changed: 16 additions & 3 deletions
@@ -1029,14 +1029,27 @@ def test_randomuniform_int(self):
         # since results are random, compare the shapes only
         self._run_test_case([_OUTPUT], {}, check_value=False, check_shape=True)
 
-    @unittest.skip("")
+    @skip_caffe2_backend()
     def test_argminmax(self):
-        # TODO: fails on onnxmsrt caffe2
         x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
-        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+        x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
         x_ = tf.argmin(x, axis=0)
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
+        tf.reset_default_graph()
+
+        x_val = np.array([1, 2, -2, -1], dtype=np.int32).reshape((2, 2))
+        x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
+        x_ = tf.argmax(x)
+        _ = tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val})
+        tf.reset_default_graph()
+
+        x_val = np.array([1, 2, -2, -1], dtype=np.int32).reshape((2, 2))
+        x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
+        x_ = tf.argmax(x, output_type=x_val.dtype)
+        _ = tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
     @check_opset_min_version(6, "cast")
     def test_cast(self):
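The third block above exercises exactly the case the converter change targets: tf.argmax returns int64 indices by default, and output_type requests int32. A quick standalone check, assuming the same TF 1.x API the test uses:

    import numpy as np
    import tensorflow as tf

    x_val = np.array([1, 2, -2, -1], dtype=np.int32).reshape((2, 2))
    x = tf.placeholder(x_val.dtype, x_val.shape)
    print(tf.argmax(x).dtype)                        # <dtype: 'int64'> (default)
    print(tf.argmax(x, output_type=tf.int32).dtype)  # <dtype: 'int32'>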

tf2onnx/tfonnx.py

Lines changed: 21 additions & 6 deletions
@@ -64,7 +64,7 @@ def tflist_to_onnx(node_list, shape_override):
     # ignore the following attributes
     ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
                     "TI", "Tparams", "Tindices", "Tlen", "Tdim", "dynamic_size", "Tmultiples",
-                    "output_dtype", "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
+                    "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
                     "Tout", "Tlabels", "Tindex", "element_shape"]
     # some stats
     op_cnt = collections.Counter()
@@ -239,10 +239,22 @@ def arg_minmax_op(ctx, node, name, args):
     dim_count = len(input_shape) if input_shape else 0
     axis = dim_count + axis
 
+    nodes = [node]
+    # TF ArgMin/ArgMax may return int32 or int64
+    # Onnx ArgMin/ArgMax only supports int64 output, add cast if needed
+    if node.get_attr_int("output_type") == onnx_pb.TensorProto.INT32:
+        # current node will return int64 after conversion, which differs from previous dtype got from tf
+        ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
+        op_name = utils.make_name("Cast")
+        cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=onnx_pb.TensorProto.INT32)
+        ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT32)
+        ctx.copy_shape(node.output[0], cast_node.output[0])
+        nodes.append(cast_node)
+
     node.set_attr("axis", axis)
     node.set_attr("keepdims", 0)
     ctx.remove_input(node, node.input[1])
-    return node
+    return nodes
 
 
 def reduce_op(ctx, node, name, args):
@@ -700,7 +712,7 @@ def biasadd_op7(ctx, node, name, args):
     shape0 = ctx.get_shape(node.input[0])
     shape1 = ctx.get_shape(node.input[1])
     if node.inputs[1].type == 'Const' and len(shape1) == 1:
-        new_broadcast_shape = [shape1[0],] + [1,] * (len(shape0) - 2)
+        new_broadcast_shape = [shape1[0]] + [1] * (len(shape0) - 2)
         shape_name = utils.make_name(node.name)
         shape_const_node = ctx.make_const(shape_name, np.array(new_broadcast_shape, dtype=np.int64))
         op_name = node.input[1]
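The only change in this hunk drops the stray trailing commas; the computed value is identical. As a worked example of what new_broadcast_shape produces (the shapes are illustrative, not from the commit):

    # A rank-1 bias of length C is reshaped to [C, 1, 1] so that ONNX Add
    # broadcasts it across the spatial dims of a rank-4 NCHW tensor.
    shape0 = [1, 64, 28, 28]  # conv output, NCHW
    shape1 = [64]             # bias
    new_broadcast_shape = [shape1[0]] + [1] * (len(shape0) - 2)
    assert new_broadcast_shape == [64, 1, 1]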
@@ -1191,7 +1203,6 @@ def minmax_op(ctx, node, name, args):
 
 
 def pack_op(ctx, node, name, args):
-
     # hack to make up for the missing onnx pack op
     axis = node.get_attr("axis").i
     if axis < 0:
@@ -1333,14 +1344,18 @@ def matmul_op(ctx, node, name, args):
         shape = ctx.get_shape(node.input[0])
         if shape:
             perm = list(range(0, len(shape)))
-            tmp = perm[-1]; perm[-1] = perm[-2]; perm[-2] = tmp
+            tmp = perm[-1]
+            perm[-1] = perm[-2]
+            perm[-2] = tmp
             transpose = ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=perm)
             nodes.insert(0, transpose)
     if transpose_b != 0:
         shape = ctx.get_shape(node.input[1])
         if shape:
             perm = list(range(0, len(shape)))
-            tmp = perm[-1]; perm[-1] = perm[-2]; perm[-2] = tmp
+            tmp = perm[-1]
+            perm[-1] = perm[-2]
+            perm[-2] = tmp
             transpose = ctx.insert_new_node_on_input(node, "Transpose", node.input[1], perm=perm)
             nodes.insert(0, transpose)
 
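Splitting the semicolon-chained statements onto separate lines is a pure lint fix; behavior is unchanged. As a style note (not part of the commit), the same swap could be written with tuple unpacking:

    # Swap the last two axes of the permutation in one statement.
    perm[-1], perm[-2] = perm[-2], perm[-1]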
