Skip to content

Commit b3c2eea

Browse files
[libc++] Constrain additional overloads of pow for complex harder
1 parent 3f8380f commit b3c2eea

File tree

2 files changed

+109
-3
lines changed

2 files changed

+109
-3
lines changed

libcxx/include/complex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,20 +1097,20 @@ inline _LIBCPP_HIDE_FROM_ABI complex<_Tp> pow(const complex<_Tp>& __x, const com
10971097
return std::exp(__y * std::log(__x));
10981098
}
10991099

1100-
template <class _Tp, class _Up>
1100+
template <class _Tp, class _Up, __enable_if_t<is_floating_point<_Tp>::value && is_floating_point<_Up>::value, int> = 0>
11011101
inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type>
11021102
pow(const complex<_Tp>& __x, const complex<_Up>& __y) {
11031103
typedef complex<typename __promote<_Tp, _Up>::type> result_type;
11041104
return std::pow(result_type(__x), result_type(__y));
11051105
}
11061106

1107-
template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Up>::value, int> = 0>
1107+
template <class _Tp, class _Up, __enable_if_t<is_floating_point<_Tp>::value && is_arithmetic<_Up>::value, int> = 0>
11081108
inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type> pow(const complex<_Tp>& __x, const _Up& __y) {
11091109
typedef complex<typename __promote<_Tp, _Up>::type> result_type;
11101110
return std::pow(result_type(__x), result_type(__y));
11111111
}
11121112

1113-
template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Tp>::value, int> = 0>
1113+
template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Tp>::value && is_floating_point<_Up>::value, int> = 0>
11141114
inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type> pow(const _Tp& __x, const complex<_Up>& __y) {
11151115
typedef complex<typename __promote<_Tp, _Up>::type> result_type;
11161116
return std::pow(result_type(__x), result_type(__y));
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// <complex>
10+
11+
// template<class T, class U> complex<__promote<T, U>::type> pow(const complex<T>&, const U&);
12+
// template<class T, class U> complex<__promote<T, U>::type> pow(const complex<T>&, const complex<U>&);
13+
// template<class T, class U> complex<__promote<T, U>::type> pow(const T&, const complex<U>&);
14+
15+
// Test that these additional overloads are free from catching std::complex<non-floating-point>,
16+
// which is expected by several 3rd party libraries, see https://github.com/llvm/llvm-project/issues/109858.
17+
18+
#include <cassert>
19+
#include <cmath>
20+
#include <complex>
21+
#include <type_traits>
22+
23+
#include "test_macros.h"
24+
25+
namespace usr {
26+
struct usr_tag {};
27+
28+
template <class T, class U>
29+
TEST_CONSTEXPR
30+
typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
31+
(std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
32+
int>::type
33+
pow(const T&, const std::complex<U>&) {
34+
return std::is_same<T, usr_tag>::value ? 0 : 1;
35+
}
36+
37+
template <class T, class U>
38+
TEST_CONSTEXPR
39+
typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
40+
(std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
41+
int>::type
42+
pow(const std::complex<T>&, const U&) {
43+
return std::is_same<U, usr_tag>::value ? 2 : 3;
44+
}
45+
46+
template <class T, class U>
47+
TEST_CONSTEXPR
48+
typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
49+
(std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
50+
int>::type
51+
pow(const std::complex<T>&, const std::complex<U>&) {
52+
return std::is_same<T, usr_tag>::value ? 4 : 5;
53+
}
54+
} // namespace usr
55+
56+
int main(int, char**) {
57+
using std::pow;
58+
using usr::pow;
59+
60+
TEST_CONSTEXPR usr::usr_tag tag;
61+
TEST_CONSTEXPR_CXX14 const std::complex<usr::usr_tag> ctag;
62+
63+
assert(pow(tag, std::complex<float>(1.0f)) == 0);
64+
assert(pow(std::complex<float>(1.0f), tag) == 2);
65+
assert(pow(tag, std::complex<double>(1.0)) == 0);
66+
assert(pow(std::complex<double>(1.0), tag) == 2);
67+
assert(pow(tag, std::complex<long double>(1.0l)) == 0);
68+
assert(pow(std::complex<long double>(1.0l), tag) == 2);
69+
70+
assert(pow(1.0f, ctag) == 1);
71+
assert(pow(ctag, 1.0f) == 3);
72+
assert(pow(1.0, ctag) == 1);
73+
assert(pow(ctag, 1.0) == 3);
74+
assert(pow(1.0l, ctag) == 1);
75+
assert(pow(ctag, 1.0l) == 3);
76+
77+
assert(pow(ctag, std::complex<float>(1.0f)) == 4);
78+
assert(pow(std::complex<float>(1.0f), ctag) == 5);
79+
assert(pow(ctag, std::complex<double>(1.0)) == 4);
80+
assert(pow(std::complex<double>(1.0), ctag) == 5);
81+
assert(pow(ctag, std::complex<long double>(1.0l)) == 4);
82+
assert(pow(std::complex<long double>(1.0l), ctag) == 5);
83+
84+
#if TEST_STD_VER >= 11
85+
static_assert(pow(tag, std::complex<float>(1.0f)) == 0, "");
86+
static_assert(pow(std::complex<float>(1.0f), tag) == 2, "");
87+
static_assert(pow(tag, std::complex<double>(1.0)) == 0, "");
88+
static_assert(pow(std::complex<double>(1.0), tag) == 2, "");
89+
static_assert(pow(tag, std::complex<long double>(1.0l)) == 0, "");
90+
static_assert(pow(std::complex<long double>(1.0l), tag) == 2, "");
91+
92+
static_assert(pow(1.0f, ctag) == 1, "");
93+
static_assert(pow(ctag, 1.0f) == 3, "");
94+
static_assert(pow(1.0, ctag) == 1, "");
95+
static_assert(pow(ctag, 1.0) == 3, "");
96+
static_assert(pow(1.0l, ctag) == 1, "");
97+
static_assert(pow(ctag, 1.0l) == 3, "");
98+
99+
static_assert(pow(ctag, std::complex<float>(1.0f)) == 4, "");
100+
static_assert(pow(std::complex<float>(1.0f), ctag) == 5, "");
101+
static_assert(pow(ctag, std::complex<double>(1.0)) == 4, "");
102+
static_assert(pow(std::complex<double>(1.0), ctag) == 5, "");
103+
static_assert(pow(ctag, std::complex<long double>(1.0l)) == 4, "");
104+
static_assert(pow(std::complex<long double>(1.0l), ctag) == 5, "");
105+
#endif
106+
}

0 commit comments

Comments
 (0)