Skip to content

Commit 0b0c5ab

Browse files
authored
Merge pull request #1016 from xadupre/i1011atan2
Fixes #1011, add support for atan2
2 parents 85bca92 + acfdccf commit 0b0c5ab

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from itertools import product
1515

1616
import numpy as np
17+
from numpy.testing import assert_almost_equal
1718
import tensorflow as tf
1819

1920
from tensorflow.python.ops import lookup_ops
@@ -3352,6 +3353,29 @@ def func(base_matrix, diag, k):
33523353

33533354
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})
33543355

3356+
@check_opset_min_version(9, "atan2")
3357+
def test_atan2(self):
3358+
# Test all possible pairs of pos, neg, zero for x and y.
3359+
3360+
def atan2(y, x):
3361+
sx = np.sign(x)
3362+
sy = np.sign(y)
3363+
pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-np.pi/2)
3364+
atan_part = np.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
3365+
return atan_part + pi_part
3366+
3367+
test_pairs = [[y, x] for x in [3., -4., 0.] for y in [5., -6., 0.]]
3368+
y_val = np.array([y for y, x in test_pairs], dtype=np.float32)
3369+
x_val = np.array([x for y, x in test_pairs], dtype=np.float32)
3370+
assert_almost_equal(np.arctan2(y_val, x_val), atan2(y_val, x_val))
3371+
3372+
def func(y, x):
3373+
atan2_ = tf.math.atan2(y, x)
3374+
return tf.identity(atan2_, name=_TFOUTPUT)
3375+
3376+
self._run_test_case(
3377+
func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
3378+
33553379

33563380
if __name__ == '__main__':
33573381
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,113 @@ def version_10(cls, ctx, node, **kwargs):
586586
shapes=shapes, dtypes=dtypes)
587587
_ = ctx.make_node("Not", inputs=or_node.output, name=node.name,
588588
shapes=shapes, dtypes=dtypes)
589+
590+
591+
@tf_op("Atan2")
592+
class Atan2Op:
593+
# support more dtype
594+
supported_dtypes = [
595+
onnx_pb.TensorProto.FLOAT,
596+
onnx_pb.TensorProto.FLOAT16,
597+
onnx_pb.TensorProto.DOUBLE
598+
]
599+
600+
@classmethod
601+
def version_9(cls, ctx, node, **kwargs):
602+
"""
603+
Obtained with a linear regression.
604+
605+
::
606+
607+
def atan2(y, x):
608+
sx = numpy.sign(x)
609+
sy = numpy.sign(y)
610+
pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)
611+
atan_part = numpy.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
612+
return atan_part + pi_part
613+
"""
614+
615+
onnx_dtype = ctx.get_dtype(node.input[0])
616+
shape = ctx.get_shape(node.input[0])
617+
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
618+
619+
# sign part
620+
621+
sign_x_node = ctx.make_node(
622+
"Sign", inputs=node.input[1:],
623+
name=utils.make_name(node.name + 'signx'))
624+
sign_y_node = ctx.make_node(
625+
"Sign", inputs=node.input[:1],
626+
name=utils.make_name(node.name + 'signy'))
627+
628+
sx_node = ctx.make_node(
629+
"Cast", sign_x_node.output[:1], attr={"to": onnx_dtype},
630+
name=utils.make_name(node.name + 'csignx'))
631+
sy_node = ctx.make_node(
632+
"Cast", sign_y_node.output[:1], attr={"to": onnx_dtype},
633+
name=utils.make_name(node.name + 'csigny'))
634+
635+
# cst
636+
637+
one_node = ctx.make_const(
638+
utils.make_name("{}_one".format(node.name)),
639+
np.array([1], dtype=np_dtype))
640+
641+
pib2_node = ctx.make_const(
642+
utils.make_name("{}_pi".format(node.name)),
643+
np.array(- np.pi / 2, dtype=np_dtype))
644+
645+
# pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)
646+
647+
sxm1_node = ctx.make_node(
648+
"Sub", [sx_node.output[0], one_node.output[0]],
649+
name=utils.make_name(node.name + 'sxm1'))
650+
sy2_node = ctx.make_node(
651+
"Mul", [sy_node.output[0], sy_node.output[0]],
652+
name=utils.make_name(node.name + 'sy2'))
653+
sy2m1_node = ctx.make_node(
654+
"Sub", [sy2_node.output[0], one_node.output[0]],
655+
name=utils.make_name(node.name + 'sy2m1'))
656+
sxsy2m1_node = ctx.make_node(
657+
"Mul", [sx_node.output[0], sy2m1_node.output[0]],
658+
name=utils.make_name(node.name + 'sxsy2m1'))
659+
sysxsy2m1_node = ctx.make_node(
660+
"Add", [sy_node.output[0], sxsy2m1_node.output[0]],
661+
name=utils.make_name(node.name + 'sysxsy2m1'))
662+
m1_node = ctx.make_node(
663+
"Mul", [sysxsy2m1_node.output[0], sxm1_node.output[0]],
664+
name=utils.make_name(node.name + 'm1'))
665+
pi_part = ctx.make_node(
666+
"Mul", [m1_node.output[0], pib2_node.output[0]],
667+
name=utils.make_name(node.name + 'pip'))
668+
669+
# atan
670+
671+
sx2_node = ctx.make_node(
672+
"Mul", [sx_node.output[0], sx_node.output[0]],
673+
name=utils.make_name(node.name + 'sx2'))
674+
sx2m1_node = ctx.make_node(
675+
"Sub", [sx2_node.output[0], one_node.output[0]],
676+
name=utils.make_name(node.name + 'sx2m1'))
677+
xsx2m1_node = ctx.make_node(
678+
"Add", [node.input[1], sx2m1_node.output[0]],
679+
name=utils.make_name(node.name + 'xsx2m1'))
680+
div_node = ctx.make_node(
681+
"Div", inputs=[node.input[0], xsx2m1_node.output[0]],
682+
name=utils.make_name(node.name + 'div'))
683+
atan0_node = ctx.make_node(
684+
"Atan", inputs=[div_node.output[0]],
685+
name=utils.make_name(node.name + 'atan0'))
686+
atan_node = ctx.make_node(
687+
"Mul", inputs=[sx2_node.output[0], atan0_node.output[0]],
688+
name=utils.make_name(node.name + 'atan'))
689+
690+
# final
691+
692+
ctx.remove_node(node.name)
693+
694+
last_node = ctx.make_node(
695+
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
696+
op_name_scope=node.name + 'all',
697+
shapes=[shape], dtypes=[onnx_dtype])
698+
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])

0 commit comments

Comments
 (0)