Skip to content

Commit a6dd869

Browse files
committed
Fix gh-2121 in convert_impl function
Since convert_impl input argument is a reference, implement NumPy's interpretation of bool (underlying byte has any bits set in it) to override C++'s interpretation (underlying byte has the first bit set). To allow such an intepretation to work correctly bool arguments must be passed by reference, rather than by value. Passing by value creates a copy where C++ masks higher bits out.
1 parent 4be75e8 commit a6dd869

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

dpctl/tensor/libtensor/include/utils/type_utils.hpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#pragma once
2626
#include <complex>
2727
#include <cstddef>
28+
#include <cstdint>
2829
#include <stdexcept>
2930
#include <sycl/sycl.hpp>
3031
#include <type_traits>
@@ -55,26 +56,39 @@ template <typename T> inline constexpr bool is_complex_v = is_complex<T>::value;
5556

5657
template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
5758
{
58-
if constexpr (std::is_same<dstTy, srcTy>::value) {
59+
if constexpr (std::is_same_v<dstTy, srcTy>) {
5960
return v;
6061
}
61-
else if constexpr (std::is_same_v<dstTy, bool> && is_complex<srcTy>::value)
62-
{
63-
// bool(complex_v) == (complex_v.real() != 0) && (complex_v.imag() !=0)
64-
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
65-
convert_impl<bool, typename srcTy::value_type>(v.imag()));
62+
else if constexpr (std::is_same_v<dstTy, bool>) {
63+
if constexpr (is_complex_v<srcTy>) {
64+
// bool(complex_v) ==
65+
// (complex_v.real() != 0) && (complex_v.imag() !=0)
66+
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
67+
convert_impl<bool, typename srcTy::value_type>(v.imag()));
68+
}
69+
else {
70+
return static_cast<dstTy>(v != srcTy{0});
71+
}
72+
}
73+
else if constexpr (std::is_same_v<srcTy, bool>) {
74+
const std::uint8_t &u = sycl::bit_cast<std::uint8_t>(v);
75+
if constexpr (is_complex_v<dstTy>) {
76+
return (u == 0) ? dstTy{} : dstTy{1, 0};
77+
}
78+
else {
79+
return (u == 0) ? dstTy{} : dstTy{1};
80+
}
6681
}
67-
else if constexpr (is_complex<srcTy>::value && !is_complex<dstTy>::value) {
82+
else if constexpr (is_complex_v<srcTy> && !is_complex_v<dstTy>) {
6883
// real_t(complex_v) == real_t(complex_v.real())
6984
return convert_impl<dstTy, typename srcTy::value_type>(v.real());
7085
}
71-
else if constexpr (!std::is_integral<srcTy>::value &&
72-
!std::is_same<dstTy, bool>::value &&
73-
std::is_integral<dstTy>::value &&
74-
std::is_unsigned<dstTy>::value)
86+
else if constexpr (!std::is_integral_v<srcTy> &&
87+
!std::is_same_v<dstTy, bool> &&
88+
std::is_integral_v<dstTy> && std::is_unsigned_v<dstTy>)
7589
{
7690
// first cast to signed variant, the cast to unsigned one
77-
using signedT = typename std::make_signed<dstTy>::type;
91+
using signedT = typename std::make_signed_t<dstTy>;
7892
return static_cast<dstTy>(convert_impl<signedT, srcTy>(v));
7993
}
8094
else {

0 commit comments

Comments
 (0)