Skip to content

Commit 2de4f06

Browse files
committed
Add a dedicated kernel for sycl::half input dtype
Add a comment to casting in ErfFunctor implementation
1 parent a860f6a commit 2de4f06

File tree

2 files changed

+17
-5
lines changed
  • dpnp/backend
    • extensions/ufunc/elementwise_functions
    • kernels/elementwise_functions

2 files changed

+17
-5
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/erf.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,15 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6262
template <typename T>
6363
struct OutputType
6464
{
65-
using value_type =
66-
typename std::disjunction<td_ns::TypeMapResultEntry<T, float>,
67-
td_ns::TypeMapResultEntry<T, double>,
68-
td_ns::DefaultResultEntry<void>>::result_type;
65+
/**
66+
* scipy>=1.16 assumes a pair 'e->d', but dpnp 'e->f' without an extra
67+
* kernel 'e->d' (when fp64 supported) to reduce memory footprint
68+
*/
69+
using value_type = typename std::disjunction<
70+
td_ns::TypeMapResultEntry<T, sycl::half, float>,
71+
td_ns::TypeMapResultEntry<T, float>,
72+
td_ns::TypeMapResultEntry<T, double>,
73+
td_ns::DefaultResultEntry<void>>::result_type;
6974
};
7075

7176
using dpnp::kernels::erf::ErfFunctor;

dpnp/backend/kernels/elementwise_functions/erf.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@ struct ErfFunctor
4545

4646
Tp operator()(const argT &x) const
4747
{
48-
return sycl::erf(x);
48+
if constexpr (std::is_same_v<argT, sycl::half> &&
49+
std::is_same_v<Tp, float>) {
50+
// cast sycl::half to float for accuracy reasons
51+
return sycl::erf(float(x));
52+
}
53+
else {
54+
return sycl::erf(x);
55+
}
4956
}
5057
};
5158
} // namespace dpnp::kernels::erf

0 commit comments

Comments
 (0)