Skip to content

Commit 37e7b39

Browse files
committed
Fix
Signed-off-by: Justin Chu <[email protected]>
1 parent 94b1576 commit 37e7b39

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,9 @@ def aten_atan2(self: TFloat, other: TFloat) -> TFloat:
933933
slope = op.Div(self, other)
934934
atan = op.Atan(slope)
935935
zero = common_ops.constant(0.0, dtype=self.dtype)
936+
pi = common_ops.constant(_MATH_PI, dtype=self.dtype)
936937

937-
second_third_quadrant = op.Where(op.Greater(self, zero), atan + _MATH_PI, atan - _MATH_PI)
938+
second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi)
938939
result = op.Where(op.Less(other, zero), second_third_quadrant, atan)
939940

940941
# Map NaN to 0 to match PyTorch behavior

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def _where_input_wrangler(
578578
TorchLibOpInfo("asin", core_ops.aten_asin),
579579
TorchLibOpInfo("asinh", core_ops.aten_asinh),
580580
TorchLibOpInfo("atan", core_ops.aten_atan),
581-
TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}),
581+
TorchLibOpInfo("atan2", core_ops.aten_atan2),
582582
TorchLibOpInfo("atanh", core_ops.aten_atanh),
583583
TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip(
584584
matcher=lambda sample: isinstance(sample.input, (list, tuple)),

0 commit comments

Comments
 (0)