Skip to content
This repository was archived by the owner on Jan 26, 2024. It is now read-only.

Commit 20cc63e

Browse files
committed
Only overload vector operators for convertible types
This is necessary to avoid errors about ambiguous overloads in code such as ```` struct Point { float4 pos; float mass; }; template<typename T> Point operator+(Point const& p, T const& v) { return Point{p.pos + v, p.mass}; } int main() { float4 v = make_float4(0, 1, 2, 3); Point p{make_float4(3, 2, 1, 0), 1.0f}; Point q = p + v; } ```` when building with the host compiler. Closes hipamd issue #4
1 parent a79708e commit 20cc63e

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

include/hip/amd_detail/amd_hip_vector_types.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,13 @@ typedef basic_ostream<char> ostream;
939939
}
940940
};
941941

942+
// comfort type to only enable an operator if U can be converted to
943+
// a HIP_vector_type<T, N>
944+
template<typename U, typename T, unsigned int n,
945+
typename R=HIP_vector_type<T, n> /* operator return value */>
946+
using enable_if_convertible = typename
947+
std::enable_if<std::is_convertible<U, HIP_vector_type<T, n>>::value, R>::type;
948+
942949
template<typename T, unsigned int n>
943950
__HOST_DEVICE__
944951
inline
@@ -952,7 +959,7 @@ typedef basic_ostream<char> ostream;
952959
__HOST_DEVICE__
953960
inline
954961
constexpr
955-
HIP_vector_type<T, n> operator+(
962+
enable_if_convertible<U, T, n> operator+(
956963
const HIP_vector_type<T, n>& x, U y) noexcept
957964
{
958965
return HIP_vector_type<T, n>{x} += HIP_vector_type<T, n>{y};
@@ -961,7 +968,7 @@ typedef basic_ostream<char> ostream;
961968
__HOST_DEVICE__
962969
inline
963970
constexpr
964-
HIP_vector_type<T, n> operator+(
971+
enable_if_convertible<U, T, n> operator+(
965972
U x, const HIP_vector_type<T, n>& y) noexcept
966973
{
967974
return HIP_vector_type<T, n>{x} += y;
@@ -980,7 +987,7 @@ typedef basic_ostream<char> ostream;
980987
__HOST_DEVICE__
981988
inline
982989
constexpr
983-
HIP_vector_type<T, n> operator-(
990+
enable_if_convertible<U, T, n> operator-(
984991
const HIP_vector_type<T, n>& x, U y) noexcept
985992
{
986993
return HIP_vector_type<T, n>{x} -= HIP_vector_type<T, n>{y};
@@ -989,7 +996,7 @@ typedef basic_ostream<char> ostream;
989996
__HOST_DEVICE__
990997
inline
991998
constexpr
992-
HIP_vector_type<T, n> operator-(
999+
enable_if_convertible<U, T, n> operator-(
9931000
U x, const HIP_vector_type<T, n>& y) noexcept
9941001
{
9951002
return HIP_vector_type<T, n>{x} -= y;
@@ -1008,7 +1015,7 @@ typedef basic_ostream<char> ostream;
10081015
__HOST_DEVICE__
10091016
inline
10101017
constexpr
1011-
HIP_vector_type<T, n> operator*(
1018+
enable_if_convertible<U, T, n> operator*(
10121019
const HIP_vector_type<T, n>& x, U y) noexcept
10131020
{
10141021
return HIP_vector_type<T, n>{x} *= HIP_vector_type<T, n>{y};
@@ -1017,7 +1024,7 @@ typedef basic_ostream<char> ostream;
10171024
__HOST_DEVICE__
10181025
inline
10191026
constexpr
1020-
HIP_vector_type<T, n> operator*(
1027+
enable_if_convertible<U, T, n> operator*(
10211028
U x, const HIP_vector_type<T, n>& y) noexcept
10221029
{
10231030
return HIP_vector_type<T, n>{x} *= y;
@@ -1036,7 +1043,7 @@ typedef basic_ostream<char> ostream;
10361043
__HOST_DEVICE__
10371044
inline
10381045
constexpr
1039-
HIP_vector_type<T, n> operator/(
1046+
enable_if_convertible<U, T, n> operator/(
10401047
const HIP_vector_type<T, n>& x, U y) noexcept
10411048
{
10421049
return HIP_vector_type<T, n>{x} /= HIP_vector_type<T, n>{y};
@@ -1045,7 +1052,7 @@ typedef basic_ostream<char> ostream;
10451052
__HOST_DEVICE__
10461053
inline
10471054
constexpr
1048-
HIP_vector_type<T, n> operator/(
1055+
enable_if_convertible<U, T, n> operator/(
10491056
U x, const HIP_vector_type<T, n>& y) noexcept
10501057
{
10511058
return HIP_vector_type<T, n>{x} /= y;
@@ -1074,15 +1081,15 @@ typedef basic_ostream<char> ostream;
10741081
__HOST_DEVICE__
10751082
inline
10761083
constexpr
1077-
bool operator==(const HIP_vector_type<T, n>& x, U y) noexcept
1084+
enable_if_convertible<U, T, n, bool> operator==(const HIP_vector_type<T, n>& x, U y) noexcept
10781085
{
10791086
return x == HIP_vector_type<T, n>{y};
10801087
}
10811088
template<typename T, unsigned int n, typename U>
10821089
__HOST_DEVICE__
10831090
inline
10841091
constexpr
1085-
bool operator==(U x, const HIP_vector_type<T, n>& y) noexcept
1092+
enable_if_convertible<U, T, n, bool> operator==(U x, const HIP_vector_type<T, n>& y) noexcept
10861093
{
10871094
return HIP_vector_type<T, n>{x} == y;
10881095
}
@@ -1100,15 +1107,15 @@ typedef basic_ostream<char> ostream;
11001107
__HOST_DEVICE__
11011108
inline
11021109
constexpr
1103-
bool operator!=(const HIP_vector_type<T, n>& x, U y) noexcept
1110+
enable_if_convertible<U, T, n, bool> operator!=(const HIP_vector_type<T, n>& x, U y) noexcept
11041111
{
11051112
return !(x == y);
11061113
}
11071114
template<typename T, unsigned int n, typename U>
11081115
__HOST_DEVICE__
11091116
inline
11101117
constexpr
1111-
bool operator!=(U x, const HIP_vector_type<T, n>& y) noexcept
1118+
enable_if_convertible<U, T, n, bool> operator!=(U x, const HIP_vector_type<T, n>& y) noexcept
11121119
{
11131120
return !(x == y);
11141121
}

0 commit comments

Comments
 (0)