Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,9 +742,6 @@ def astype(
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
# see #2121
if ary_dtype == dpt.bool:
usm_ary = dpt.not_equal(usm_ary, 0, order=copy_order)
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
return R

Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class GenericCopyFunctor
const ssize_t &src_offset = offsets.get_first_offset();
const ssize_t &dst_offset = offsets.get_second_offset();

CastFnT fn{};
static constexpr CastFnT fn{};
dst_[dst_offset] = fn(src_[src_offset]);
}
};
Expand Down Expand Up @@ -237,9 +237,9 @@ class ContigCopyFunctor

static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;

using dpctl::tensor::type_utils::is_complex;
if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
is_complex<dstT>::value)
using dpctl::tensor::type_utils::is_complex_v;
if constexpr (!enable_sg_loadstore || is_complex_v<srcT> ||
is_complex_v<dstT>)
{
std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0];
const std::size_t gid = ndit.get_global_linear_id();
Expand Down
38 changes: 26 additions & 12 deletions dpctl/tensor/libtensor/include/utils/type_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#pragma once
#include <complex>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <sycl/sycl.hpp>
#include <type_traits>
Expand Down Expand Up @@ -55,26 +56,39 @@ template <typename T> inline constexpr bool is_complex_v = is_complex<T>::value;

template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
{
if constexpr (std::is_same<dstTy, srcTy>::value) {
if constexpr (std::is_same_v<dstTy, srcTy>) {
return v;
}
else if constexpr (std::is_same_v<dstTy, bool> && is_complex<srcTy>::value)
{
// bool(complex_v) == (complex_v.real() != 0) || (complex_v.imag() != 0)
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
convert_impl<bool, typename srcTy::value_type>(v.imag()));
else if constexpr (std::is_same_v<dstTy, bool>) {
if constexpr (is_complex_v<srcTy>) {
// bool(complex_v) ==
// (complex_v.real() != 0) || (complex_v.imag() != 0)
return (convert_impl<bool, typename srcTy::value_type>(v.real()) ||
convert_impl<bool, typename srcTy::value_type>(v.imag()));
}
else {
return static_cast<dstTy>(v != srcTy{0});
}
}
else if constexpr (std::is_same_v<srcTy, bool>) {
const std::uint8_t &u = sycl::bit_cast<std::uint8_t>(v);
if constexpr (is_complex_v<dstTy>) {
return (u == 0) ? dstTy{} : dstTy{1, 0};
}
else {
return (u == 0) ? dstTy{} : dstTy{1};
}
}
else if constexpr (is_complex<srcTy>::value && !is_complex<dstTy>::value) {
else if constexpr (is_complex_v<srcTy> && !is_complex_v<dstTy>) {
// real_t(complex_v) == real_t(complex_v.real())
return convert_impl<dstTy, typename srcTy::value_type>(v.real());
}
else if constexpr (!std::is_integral<srcTy>::value &&
!std::is_same<dstTy, bool>::value &&
std::is_integral<dstTy>::value &&
std::is_unsigned<dstTy>::value)
else if constexpr (!std::is_integral_v<srcTy> &&
!std::is_same_v<dstTy, bool> &&
std::is_integral_v<dstTy> && std::is_unsigned_v<dstTy>)
{
// first cast to the signed variant, then cast to the unsigned one
using signedT = typename std::make_signed<dstTy>::type;
using signedT = typename std::make_signed_t<dstTy>;
return static_cast<dstTy>(convert_impl<signedT, srcTy>(v));
}
else {
Expand Down
Loading