
Commit 17d5aa4

Aminsed authored and pytorchmergebot committed
disable jiterator for complex tan and tanh (pytorch#165250)
Fixes pytorch#100842

Disable jiterator for complex tan and tanh kernels due to accuracy issues, matching the existing approach used for acos, acosh, asin, and asinh. Reverts to the thrust implementation, which provides better numerical accuracy.

Pull Request resolved: pytorch#165250
Approved by: https://github.com/ezyang
1 parent cde81e9 commit 17d5aa4
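
For context, here is a minimal sketch (not part of this commit) of the kind of NumPy comparison the new tests codify: evaluate torch.tan and torch.tanh on random complex64 CUDA inputs and compare the results against NumPy. It assumes a CUDA-capable PyTorch build; the sample size and tolerances are illustrative choices, not the test suite's defaults.

import math

import numpy as np
import torch

assert torch.cuda.is_available(), "requires a CUDA device"

# Random complex64 inputs spanning a few periods on the real axis.
real = torch.randn(1024, device="cuda") * (2 * math.pi)
imag = torch.randn(1024, device="cuda") * 5.0
z = torch.complex(real, imag).to(torch.complex64)

for torch_fn, np_fn in ((torch.tan, np.tan), (torch.tanh, np.tanh)):
    got = torch_fn(z).cpu().numpy()
    expected = np_fn(z.cpu().numpy())
    # Loose example tolerances; the real tests rely on compare_with_numpy
    # and the suite's default tolerances instead.
    print(torch_fn.__name__, np.allclose(got, expected, rtol=1e-4, atol=1e-5))

The tests added in test/test_unary_ufuncs.py perform essentially this comparison via compare_with_numpy, with extra sample points near the tan singularities at ±π/2.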

File tree

3 files changed: +48 −4 lines changed


aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu

Lines changed: 3 additions & 2 deletions
@@ -12,14 +12,15 @@
 
 namespace at::native {
 
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tan_name[] = "tan_impl";
 #endif
 
 void tan_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tan_string = jiterator_stringify(
         template <typename T> T tan_impl(T a) { return std::tan(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(

aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu

Lines changed: 3 additions & 2 deletions
@@ -12,14 +12,15 @@
 
 namespace at::native {
 
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tanh_name[] = "tanh_impl";
 #endif
 
 void tanh_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tanh_string = jiterator_stringify(
         template <typename T> T tanh_impl(T a) { return std::tanh(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(

test/test_unary_ufuncs.py

Lines changed: 42 additions & 0 deletions
@@ -773,6 +773,48 @@ def test_abs_angle_complex_to_float(self, device, dtype):
         with self.assertRaises(AttributeError):
             torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
 
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tan_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tan on CUDA against NumPy reference
+        # Includes values near tan singularities on the real axis
+        eps = 1e-3
+        specials = torch.tensor(
+            [
+                math.pi / 2 - eps,
+                math.pi / 2 + eps,
+                -math.pi / 2 - eps,
+                -math.pi / 2 + eps,
+            ],
+            device=device,
+            dtype=torch.float32,
+        )
+        real = torch.randn(1024, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(1024, device=device, dtype=torch.float32) * 5.0
+        real = torch.cat([real, specials])
+        imag = torch.cat(
+            [
+                imag,
+                torch.linspace(
+                    -3,
+                    3,
+                    steps=specials.numel(),
+                    device=device,
+                ),
+            ]
+        )
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tan, np.tan, z)
+
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tanh_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tanh on CUDA against NumPy reference
+        real = torch.randn(2048, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(2048, device=device, dtype=torch.float32) * 5.0
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tanh, np.tanh, z)
+
     def check_internal_mem_overlap(
         self, inplace_op, num_inputs, dtype, device, expected_failure=False
     ):
