Skip to content

Commit 55546cf

Browse files
authored
Merge pull request #2158 from sycloid/revert-gh-2122-fix-for-gh-2121
Revert gh 2122 fix for gh 2121
2 parents d16d31b + 6ab3c71 commit 55546cf

File tree

3 files changed

+36
-19
lines changed

3 files changed

+36
-19
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,9 +742,6 @@ def astype(
742742
order=copy_order,
743743
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
744744
)
745-
# see #2121
746-
if ary_dtype == dpt.bool:
747-
usm_ary = dpt.not_equal(usm_ary, 0, order=copy_order)
748745
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
749746
return R
750747

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class GenericCopyFunctor
101101
const ssize_t &src_offset = offsets.get_first_offset();
102102
const ssize_t &dst_offset = offsets.get_second_offset();
103103

104-
CastFnT fn{};
104+
static constexpr CastFnT fn{};
105105
dst_[dst_offset] = fn(src_[src_offset]);
106106
}
107107
};
@@ -237,9 +237,9 @@ class ContigCopyFunctor
237237

238238
static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
239239

240-
using dpctl::tensor::type_utils::is_complex;
241-
if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
242-
is_complex<dstT>::value)
240+
using dpctl::tensor::type_utils::is_complex_v;
241+
if constexpr (!enable_sg_loadstore || is_complex_v<srcT> ||
242+
is_complex_v<dstT>)
243243
{
244244
std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0];
245245
const std::size_t gid = ndit.get_global_linear_id();

dpctl/tensor/libtensor/include/utils/type_utils.hpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#pragma once
2626
#include <complex>
2727
#include <cstddef>
28+
#include <cstdint>
2829
#include <stdexcept>
2930
#include <sycl/sycl.hpp>
3031
#include <type_traits>
@@ -55,26 +56,45 @@ template <typename T> inline constexpr bool is_complex_v = is_complex<T>::value;
5556

5657
template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
5758
{
58-
if constexpr (std::is_same<dstTy, srcTy>::value) {
59+
if constexpr (std::is_same_v<dstTy, srcTy>) {
5960
return v;
6061
}
61-
else if constexpr (std::is_same_v<dstTy, bool> && is_complex<srcTy>::value)
62-
{
63-
// bool(complex_v) == (complex_v.real() != 0) && (complex_v.imag() !=0)
64-
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
65-
convert_impl<bool, typename srcTy::value_type>(v.imag()));
62+
else if constexpr (std::is_same_v<dstTy, bool>) {
63+
if constexpr (is_complex_v<srcTy>) {
64+
// bool(complex_v) ==
65+
// (complex_v.real() != 0) && (complex_v.imag() !=0)
66+
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
67+
convert_impl<bool, typename srcTy::value_type>(v.imag()));
68+
}
69+
else {
70+
return static_cast<dstTy>(v != srcTy{0});
71+
}
72+
}
73+
else if constexpr (std::is_same_v<srcTy, bool>) {
74+
// C++ interprets a byte of storage behind bool by only
75+
// testing is least significant bit, leading to both
76+
// 0x00 and 0x02 interpreted as False, while 0x01 and 0xFF
77+
// interpreted as True. NumPy's interpretation of underlying
78+
// storage is different: any bit set is interpreted as True,
79+
// no bits set as False, see gh-2121
80+
const std::uint8_t &u = sycl::bit_cast<std::uint8_t>(v);
81+
if constexpr (is_complex_v<dstTy>) {
82+
return (u == 0) ? dstTy{} : dstTy{1, 0};
83+
}
84+
else {
85+
return (u == 0) ? dstTy{} : dstTy{1};
86+
}
6687
}
67-
else if constexpr (is_complex<srcTy>::value && !is_complex<dstTy>::value) {
88+
else if constexpr (is_complex_v<srcTy> && !is_complex_v<dstTy>) {
6889
// real_t(complex_v) == real_t(complex_v.real())
6990
return convert_impl<dstTy, typename srcTy::value_type>(v.real());
7091
}
71-
else if constexpr (!std::is_integral<srcTy>::value &&
72-
!std::is_same<dstTy, bool>::value &&
73-
std::is_integral<dstTy>::value &&
74-
std::is_unsigned<dstTy>::value)
92+
else if constexpr (!std::is_integral_v<srcTy> &&
93+
!std::is_same_v<dstTy, bool> &&
94+
std::is_integral_v<dstTy> && std::is_unsigned_v<dstTy>)
7595
{
7696
// first cast to signed variant, the cast to unsigned one
77-
using signedT = typename std::make_signed<dstTy>::type;
97+
using signedT = typename std::make_signed_t<dstTy>;
7898
return static_cast<dstTy>(convert_impl<signedT, srcTy>(v));
7999
}
80100
else {

0 commit comments

Comments
 (0)