Skip to content

Commit d18d3f7

Browse files
Implement Round for lower opsets (#1589)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent ec39a9f commit d18d3f7

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,6 +4349,14 @@ def func(x):
43494349
return tf.identity(x_, name=_TFOUTPUT)
43504350
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
43514351

4352+
def test_round_approx(self):
4353+
# In lower opsets there is no Round, but we can approximate it forgoing nearest even
4354+
x_val = np.array([-0.7, -0.5, -0.0, 0.0, +0.0, 0.3, 1.5, 0.7, float('nan')], dtype=np.float32)
4355+
def func(x):
4356+
x_ = tf.round(x)
4357+
return tf.identity(x_, name=_TFOUTPUT)
4358+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
4359+
43524360
@check_opset_min_version(11, "Det")
43534361
@unittest.skip("unclear how this is called in tf-2, fix later")
43544362
def test_determinant(self):

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,15 @@ def version_11(cls, ctx, node, **kwargs):
535535

536536
@tf_op("Round")
537537
class Round:
538+
@classmethod
539+
def version_1(cls, ctx, node, **kwargs):
540+
# Not exactly nearest even but close enough
541+
np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
542+
const_half = ctx.make_const(utils.make_name("const_half"), np.array(0.5, np_dtype)).output[0]
543+
add_node = ctx.make_node("Add", [node.input[0], const_half], op_name_scope=node.name).output[0]
544+
node.type = "Floor"
545+
ctx.replace_inputs(node, [add_node])
546+
538547
@classmethod
539548
def version_11(cls, ctx, node, **kwargs):
540549
pass

0 commit comments

Comments
 (0)