
Commit fe2a433

Implement reduce all/any for non-const axes (#1657)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent e912e11 commit fe2a433
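
For context, a "non-const axis" here means the reduction axes arrive as a tensor at runtime rather than as a Python constant, so the converter cannot fold them into an attribute. A minimal, illustrative sketch of the pattern this commit enables (the function name and signature below are hypothetical; the new test_reduce_any_nonconst_axis added in this commit exercises the same shape of graph):

    import tensorflow as tf

    @tf.function(input_signature=[
        tf.TensorSpec([2, 20], tf.bool),
        tf.TensorSpec([1], tf.int32),
    ])
    def reduce_any_dynamic(x, axis):
        # axis is a runtime tensor, so its value is unknown at conversion time
        return tf.reduce_any(x, axis=axis, keepdims=False)

From opset 13 onward, ONNX ReduceSum accepts its axes as an input tensor rather than an attribute, which is what the new version_13 handler in reduction.py relies on.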

File tree

2 files changed: +50 -9 lines
  tests/test_backend.py
  tf2onnx/onnx_opset/reduction.py


tests/test_backend.py

Lines changed: 27 additions & 9 deletions
@@ -3582,29 +3582,29 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
 
     def test_reduce_all(self):
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_all(input_tensor=x, keepdims=False)
             res1 = tf.reduce_all(input_tensor=x, axis=[0], keepdims=False)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(input_x):
             res = tf.reduce_all(input_tensor=input_x, keepdims=True)
             res1 = tf.reduce_all(input_tensor=input_x, axis=[0], keepdims=True)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
     def test_reduce_any(self):
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_any(input_tensor=x, keepdims=False)
             res1 = tf.reduce_any(input_tensor=x, axis=[0], keepdims=False)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_any(input_tensor=x, keepdims=True)
             res1 = tf.reduce_any(input_tensor=x, axis=[0], keepdims=True)
@@ -3613,14 +3613,14 @@ def func(x):
 
     @check_opset_min_version(11, "ReduceMin")
     def test_reduce_all_negative_axis(self):
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
        def func(x):
             res = tf.reduce_all(input_tensor=x, keepdims=False)
             res1 = tf.reduce_all(input_tensor=x, axis=[-1], keepdims=False)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(input_x):
             res = tf.reduce_all(input_tensor=input_x, keepdims=True)
             res1 = tf.reduce_all(input_tensor=input_x, axis=[-1], keepdims=True)
@@ -3629,14 +3629,14 @@ def func(input_x):
 
     @check_opset_min_version(11, "ReduceSum")
     def test_reduce_any_negative_axis(self):
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_any(input_tensor=x, keepdims=False)
             res1 = tf.reduce_any(input_tensor=x, axis=[-1], keepdims=False)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_any(input_tensor=x, keepdims=True)
             res1 = tf.reduce_any(input_tensor=x, axis=[-1], keepdims=True)
@@ -3646,13 +3646,31 @@ def func(x):
     @check_opset_min_version(11, "ReduceSum")
     @check_tf_min_version("1.15")
     def test_reduce_any_empty_axis(self):
-        input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
         def func(x):
             res = tf.reduce_any(input_tensor=x, keepdims=False)
             res1 = tf.reduce_any(input_tensor=x, axis=[], keepdims=False)
             return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
 
+    def test_reduce_all_scalar_axis(self):
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
+        def func(x):
+            res = tf.reduce_all(input_tensor=x, keepdims=False)
+            res1 = tf.reduce_all(input_tensor=x, axis=0, keepdims=False)
+            return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
+
+    @check_opset_min_version(13, "ReduceSum")
+    @check_tf_min_version("1.15")
+    def test_reduce_any_nonconst_axis(self):
+        input_val = np.random.randint(0, 2, (2, 20)).astype(np.bool)
+        y_val = np.array([1], np.int32)
+        def func(x, y):
+            res = tf.reduce_any(input_tensor=x, axis=y, keepdims=False)
+            return tf.identity(res, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: y_val})
+
     @check_opset_min_version(7, "fill")
     def test_zeros_like(self):
         input_x = np.random.random_sample([10, 20]).astype(np.float32)

tf2onnx/onnx_opset/reduction.py

Lines changed: 23 additions & 0 deletions
@@ -144,6 +144,29 @@ def version_6(cls, ctx, node, **kwargs):
         ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                       name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
 
+    @classmethod
+    def version_13(cls, ctx, node, **kwargs):
+        keepdims = node.get_attr_value('keep_dims')
+        reduce_input = node.input[0]
+        if node.type == "All":
+            reduce_input = ctx.make_node("Not", [reduce_input]).output[0]
+        cast = ctx.make_node("Cast", inputs=[reduce_input], attr={"to": onnx_pb.TensorProto.FLOAT}).output[0]
+        axes_cast = node.input[1]
+        if ctx.get_rank(axes_cast) == 0:
+            # Unsqueeze scalar axes
+            axes_cast = GraphBuilder(ctx).make_unsqueeze({'data': axes_cast, 'axes': [0]})
+        if ctx.get_dtype(axes_cast) != onnx_pb.TensorProto.INT64:
+            axes_cast = ctx.make_node("Cast", inputs=[axes_cast], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
+        reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
+            {"data": cast, "axes": axes_cast, "keepdims": keepdims, "noop_with_empty_axes": 1},
+            shapes=node.output_shapes, op_name_scope=node.name)
+        zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))
+        greater_node = ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]])
+        result = greater_node.output[0]
+        if node.type == "All":
+            result = ctx.make_node("Not", [result]).output[0]
+        ctx.replace_all_inputs(node.output[0], result)
+
 
 @tf_op("AddN")
 class AddN():
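
As a rough guide to what the new version_13 handler emits: ONNX has no boolean ReduceAny/ReduceAll operators, so the converter casts the input to float, runs ReduceSum (whose axes became a dynamic input in opset 13), and compares the sum against zero; All additionally wraps both the input and the result in Not. A small numpy sketch of that equivalence (illustrative only, not part of the commit):

    import numpy as np

    def any_via_sum(x, axes, keepdims=False):
        # Any(x) == (sum of x cast to 0/1 floats) > 0
        return x.astype(np.float32).sum(axis=tuple(axes), keepdims=keepdims) > 0

    def all_via_sum(x, axes, keepdims=False):
        # All(x) == Not(Any(Not(x)))
        return ~any_via_sum(~x, axes, keepdims)

    x = np.random.randint(0, 2, (2, 20)).astype(bool)
    assert np.array_equal(any_via_sum(x, [1]), x.any(axis=1))
    assert np.array_equal(all_via_sum(x, [1]), x.all(axis=1))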
