Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,9 +742,6 @@ def astype(
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
# see #2121
if ary_dtype == dpt.bool:
usm_ary = dpt.not_equal(usm_ary, 0, order=copy_order)
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
return R

Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class GenericCopyFunctor
const ssize_t &src_offset = offsets.get_first_offset();
const ssize_t &dst_offset = offsets.get_second_offset();

CastFnT fn{};
static constexpr CastFnT fn{};
dst_[dst_offset] = fn(src_[src_offset]);
}
};
Expand Down Expand Up @@ -237,9 +237,9 @@ class ContigCopyFunctor

static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;

using dpctl::tensor::type_utils::is_complex;
if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
is_complex<dstT>::value)
using dpctl::tensor::type_utils::is_complex_v;
if constexpr (!enable_sg_loadstore || is_complex_v<srcT> ||
is_complex_v<dstT>)
{
std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0];
const std::size_t gid = ndit.get_global_linear_id();
Expand Down
44 changes: 32 additions & 12 deletions dpctl/tensor/libtensor/include/utils/type_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#pragma once
#include <complex>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <sycl/sycl.hpp>
#include <type_traits>
Expand Down Expand Up @@ -55,26 +56,45 @@ template <typename T> inline constexpr bool is_complex_v = is_complex<T>::value;

template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
{
if constexpr (std::is_same<dstTy, srcTy>::value) {
if constexpr (std::is_same_v<dstTy, srcTy>) {
return v;
}
else if constexpr (std::is_same_v<dstTy, bool> && is_complex<srcTy>::value)
{
// bool(complex_v) == (complex_v.real() != 0) || (complex_v.imag() != 0)
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
convert_impl<bool, typename srcTy::value_type>(v.imag()));
else if constexpr (std::is_same_v<dstTy, bool>) {
if constexpr (is_complex_v<srcTy>) {
// bool(complex_v) ==
// (complex_v.real() != 0) || (complex_v.imag() != 0)
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
convert_impl<bool, typename srcTy::value_type>(v.imag()));
}
else {
return static_cast<dstTy>(v != srcTy{0});
}
}
else if constexpr (std::is_same_v<srcTy, bool>) {
// C++ interprets a byte of storage behind bool by only
// testing its least significant bit, so that both
// 0x00 and 0x02 are interpreted as False, while 0x01 and 0xFF
// are interpreted as True. NumPy's interpretation of the underlying
// storage is different: any bit set is interpreted as True,
// no bits set as False, see gh-2121
const std::uint8_t &u = sycl::bit_cast<std::uint8_t>(v);
if constexpr (is_complex_v<dstTy>) {
return (u == 0) ? dstTy{} : dstTy{1, 0};
}
else {
return (u == 0) ? dstTy{} : dstTy{1};
}
}
else if constexpr (is_complex<srcTy>::value && !is_complex<dstTy>::value) {
else if constexpr (is_complex_v<srcTy> && !is_complex_v<dstTy>) {
// real_t(complex_v) == real_t(complex_v.real())
return convert_impl<dstTy, typename srcTy::value_type>(v.real());
}
else if constexpr (!std::is_integral<srcTy>::value &&
!std::is_same<dstTy, bool>::value &&
std::is_integral<dstTy>::value &&
std::is_unsigned<dstTy>::value)
else if constexpr (!std::is_integral_v<srcTy> &&
!std::is_same_v<dstTy, bool> &&
std::is_integral_v<dstTy> && std::is_unsigned_v<dstTy>)
{
// first cast to the signed variant, then cast to the unsigned one
using signedT = typename std::make_signed<dstTy>::type;
using signedT = typename std::make_signed_t<dstTy>;
return static_cast<dstTy>(convert_impl<signedT, srcTy>(v));
}
else {
Expand Down
Loading