@@ -586,3 +586,113 @@ def version_10(cls, ctx, node, **kwargs):
586
586
shapes = shapes , dtypes = dtypes )
587
587
_ = ctx .make_node ("Not" , inputs = or_node .output , name = node .name ,
588
588
shapes = shapes , dtypes = dtypes )
589
+
590
+
591
+ @tf_op ("Atan2" )
592
+ class Atan2Op :
593
+ # support more dtype
594
+ supported_dtypes = [
595
+ onnx_pb .TensorProto .FLOAT ,
596
+ onnx_pb .TensorProto .FLOAT16 ,
597
+ onnx_pb .TensorProto .DOUBLE
598
+ ]
599
+
600
+ @classmethod
601
+ def version_9 (cls , ctx , node , ** kwargs ):
602
+ """
603
+ Obtained with a linear regression.
604
+
605
+ ::
606
+
607
+ def atan2(y, x):
608
+ sx = numpy.sign(x)
609
+ sy = numpy.sign(y)
610
+ pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)
611
+ atan_part = numpy.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
612
+ return atan_part + pi_part
613
+ """
614
+
615
+ onnx_dtype = ctx .get_dtype (node .input [0 ])
616
+ shape = ctx .get_shape (node .input [0 ])
617
+ np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
618
+
619
+ # sign part
620
+
621
+ sign_x_node = ctx .make_node (
622
+ "Sign" , inputs = node .input [1 :],
623
+ name = utils .make_name (node .name + 'signx' ))
624
+ sign_y_node = ctx .make_node (
625
+ "Sign" , inputs = node .input [:1 ],
626
+ name = utils .make_name (node .name + 'signy' ))
627
+
628
+ sx_node = ctx .make_node (
629
+ "Cast" , sign_x_node .output [:1 ], attr = {"to" : onnx_dtype },
630
+ name = utils .make_name (node .name + 'csignx' ))
631
+ sy_node = ctx .make_node (
632
+ "Cast" , sign_y_node .output [:1 ], attr = {"to" : onnx_dtype },
633
+ name = utils .make_name (node .name + 'csigny' ))
634
+
635
+ # cst
636
+
637
+ one_node = ctx .make_const (
638
+ utils .make_name ("{}_one" .format (node .name )),
639
+ np .array ([1 ], dtype = np_dtype ))
640
+
641
+ pib2_node = ctx .make_const (
642
+ utils .make_name ("{}_pi" .format (node .name )),
643
+ np .array (- np .pi / 2 , dtype = np_dtype ))
644
+
645
+ # pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)
646
+
647
+ sxm1_node = ctx .make_node (
648
+ "Sub" , [sx_node .output [0 ], one_node .output [0 ]],
649
+ name = utils .make_name (node .name + 'sxm1' ))
650
+ sy2_node = ctx .make_node (
651
+ "Mul" , [sy_node .output [0 ], sy_node .output [0 ]],
652
+ name = utils .make_name (node .name + 'sy2' ))
653
+ sy2m1_node = ctx .make_node (
654
+ "Sub" , [sy2_node .output [0 ], one_node .output [0 ]],
655
+ name = utils .make_name (node .name + 'sy2m1' ))
656
+ sxsy2m1_node = ctx .make_node (
657
+ "Mul" , [sx_node .output [0 ], sy2m1_node .output [0 ]],
658
+ name = utils .make_name (node .name + 'sxsy2m1' ))
659
+ sysxsy2m1_node = ctx .make_node (
660
+ "Add" , [sy_node .output [0 ], sxsy2m1_node .output [0 ]],
661
+ name = utils .make_name (node .name + 'sysxsy2m1' ))
662
+ m1_node = ctx .make_node (
663
+ "Mul" , [sysxsy2m1_node .output [0 ], sxm1_node .output [0 ]],
664
+ name = utils .make_name (node .name + 'm1' ))
665
+ pi_part = ctx .make_node (
666
+ "Mul" , [m1_node .output [0 ], pib2_node .output [0 ]],
667
+ name = utils .make_name (node .name + 'pip' ))
668
+
669
+ # atan
670
+
671
+ sx2_node = ctx .make_node (
672
+ "Mul" , [sx_node .output [0 ], sx_node .output [0 ]],
673
+ name = utils .make_name (node .name + 'sx2' ))
674
+ sx2m1_node = ctx .make_node (
675
+ "Sub" , [sx2_node .output [0 ], one_node .output [0 ]],
676
+ name = utils .make_name (node .name + 'sx2m1' ))
677
+ xsx2m1_node = ctx .make_node (
678
+ "Add" , [node .input [1 ], sx2m1_node .output [0 ]],
679
+ name = utils .make_name (node .name + 'xsx2m1' ))
680
+ div_node = ctx .make_node (
681
+ "Div" , inputs = [node .input [0 ], xsx2m1_node .output [0 ]],
682
+ name = utils .make_name (node .name + 'div' ))
683
+ atan0_node = ctx .make_node (
684
+ "Atan" , inputs = [div_node .output [0 ]],
685
+ name = utils .make_name (node .name + 'atan0' ))
686
+ atan_node = ctx .make_node (
687
+ "Mul" , inputs = [sx2_node .output [0 ], atan0_node .output [0 ]],
688
+ name = utils .make_name (node .name + 'atan' ))
689
+
690
+ # final
691
+
692
+ ctx .remove_node (node .name )
693
+
694
+ last_node = ctx .make_node (
695
+ "Add" , inputs = [atan_node .output [0 ], pi_part .output [0 ]],
696
+ op_name_scope = node .name + 'all' ,
697
+ shapes = [shape ], dtypes = [onnx_dtype ])
698
+ ctx .replace_all_inputs (ctx .get_nodes (), node .output [0 ], last_node .output [0 ])
0 commit comments