diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py
index 3af5ebbe19..88fd1acdb7 100644
--- a/dpctl/tensor/_copy_utils.py
+++ b/dpctl/tensor/_copy_utils.py
@@ -742,9 +742,6 @@ def astype(
         order=copy_order,
         buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
     )
-    # see #2121
-    if ary_dtype == dpt.bool:
-        usm_ary = dpt.not_equal(usm_ary, 0, order=copy_order)
     _copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
     return R
diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
index f19dcb7c8c..023c3d8717 100644
--- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
@@ -101,7 +101,7 @@ class GenericCopyFunctor
         const ssize_t &src_offset = offsets.get_first_offset();
         const ssize_t &dst_offset = offsets.get_second_offset();
 
-        CastFnT fn{};
+        static constexpr CastFnT fn{};
         dst_[dst_offset] = fn(src_[src_offset]);
     }
 };
@@ -237,9 +237,9 @@ class ContigCopyFunctor
         static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
 
-        using dpctl::tensor::type_utils::is_complex;
-        if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
-                      is_complex<dstT>::value)
+        using dpctl::tensor::type_utils::is_complex_v;
+        if constexpr (!enable_sg_loadstore || is_complex_v<srcT> ||
+                      is_complex_v<dstT>)
         {
             std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
diff --git a/dpctl/tensor/libtensor/include/utils/type_utils.hpp b/dpctl/tensor/libtensor/include/utils/type_utils.hpp
index 41c42476b6..4921659166 100644
--- a/dpctl/tensor/libtensor/include/utils/type_utils.hpp
+++ b/dpctl/tensor/libtensor/include/utils/type_utils.hpp
@@ -25,6 +25,7 @@
 #pragma once
 #include <complex>
 #include <cstddef>
+#include <cstdint>
 #include <exception>
 #include <sycl/sycl.hpp>
 #include <utility>
@@ -55,26 +56,45 @@ template <typename T>
 inline constexpr bool is_complex_v = is_complex<T>::value;
 
 template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
 {
-    if constexpr (std::is_same<srcTy, dstTy>::value) {
+    if constexpr (std::is_same_v<srcTy, dstTy>) {
         return v;
     }
-    else if constexpr (std::is_same_v<dstTy, bool> && is_complex<srcTy>::value)
-    {
-        // bool(complex_v) == (complex_v.real() != 0) && (complex_v.imag() != 0)
-        return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
-                convert_impl<bool, typename srcTy::value_type>(v.imag()));
+    else if constexpr (std::is_same_v<dstTy, bool>) {
+        if constexpr (is_complex_v<srcTy>) {
+            // bool(complex_v) ==
+            // (complex_v.real() != 0) && (complex_v.imag() != 0)
+            return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
+                    convert_impl<bool, typename srcTy::value_type>(v.imag()));
+        }
+        else {
+            return static_cast<dstTy>(v != srcTy{0});
+        }
+    }
+    else if constexpr (std::is_same_v<srcTy, bool>) {
+        // C++ interprets the byte of storage behind a bool by testing
+        // only its least significant bit, so both 0x00 and 0x02 are
+        // interpreted as False, while 0x01 and 0xFF are interpreted
+        // as True. NumPy's interpretation of the underlying storage
+        // is different: any bit set is interpreted as True,
+        // no bits set as False, see gh-2121
+        const std::uint8_t &u = sycl::bit_cast<std::uint8_t>(v);
+        if constexpr (is_complex_v<dstTy>) {
+            return (u == 0) ? dstTy{} : dstTy{1, 0};
+        }
+        else {
+            return (u == 0) ? dstTy{} : dstTy{1};
+        }
     }
-    else if constexpr (is_complex<srcTy>::value && !is_complex<dstTy>::value) {
+    else if constexpr (is_complex_v<srcTy> && !is_complex_v<dstTy>) {
         // real_t(complex_v) == real_t(complex_v.real())
         return convert_impl<dstTy, typename srcTy::value_type>(v.real());
     }
-    else if constexpr (!std::is_integral<srcTy>::value &&
-                       !std::is_same<dstTy, bool>::value &&
-                       std::is_integral<dstTy>::value &&
-                       std::is_unsigned<dstTy>::value)
+    else if constexpr (!std::is_integral_v<srcTy> &&
+                       !std::is_same_v<dstTy, bool> &&
+                       std::is_integral_v<dstTy> && std::is_unsigned_v<dstTy>)
     {
         // first cast to signed variant, then cast to unsigned one
-        using signedT = typename std::make_signed<dstTy>::type;
+        using signedT = typename std::make_signed_t<dstTy>;
         return static_cast<dstTy>(convert_impl<signedT, srcTy>(v));
     }
     else {
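
Note: for context, below is a minimal host-side sketch of the rule the new `srcTy == bool` branch of `convert_impl` implements. It uses `std::bit_cast` (C++20) in place of `sycl::bit_cast`; `from_bool_storage` and the `main` driver are illustrative stand-ins, not dpctl code, and the sketch assumes `sizeof(bool) == 1` (which `bit_cast` requires here anyway).

```cpp
#include <bit>
#include <complex>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <type_traits>

// Minimal stand-ins for dpctl's type_utils helpers (illustrative only).
template <typename T> struct is_complex : std::false_type {};
template <typename T> struct is_complex<std::complex<T>> : std::true_type {};
template <typename T> inline constexpr bool is_complex_v = is_complex<T>::value;

// Mirrors the new srcTy == bool branch of convert_impl: any set bit in the
// byte of storage behind the bool maps to one (NumPy semantics), all bits
// clear maps to zero, regardless of how the compiler would read the bool.
template <typename dstTy> dstTy from_bool_storage(const bool &v)
{
    const std::uint8_t u = std::bit_cast<std::uint8_t>(v);
    if constexpr (is_complex_v<dstTy>) {
        return (u == 0) ? dstTy{} : dstTy{1, 0};
    }
    else {
        return (u == 0) ? dstTy{} : dstTy{1};
    }
}

int main()
{
    bool b;
    const std::uint8_t raw = 0x02; // non-canonical bool storage, as in gh-2121
    std::memcpy(&b, &raw, sizeof(b)); // model a byte written behind the type

    // Evaluating such a bool directly is outside what the language
    // guarantees; compilers typically test only the least significant
    // bit, so this may print 0.
    std::cout << static_cast<int>(b) << "\n";

    // The bit_cast-based rule yields 1 for any nonzero byte.
    std::cout << from_bool_storage<int>(b) << "\n";                 // 1
    std::cout << from_bool_storage<std::complex<float>>(b) << "\n"; // (1,0)
}
```

Reading the raw byte through `bit_cast` sidesteps the compiler's licence to assume a `bool` object only ever holds 0 or 1, which is what lets the cast kernels agree with NumPy's "any bit set is True" convention for non-canonical storage.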