Skip to content

Commit dee6047

Browse files
committed
Reduce duplication by combining erf-like functors
1 parent b9406ed commit dee6047

File tree

3 files changed

+27
-69
lines changed

3 files changed

+27
-69
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/erf_funcs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
#include "erf_funcs.hpp"
3434
#include "kernels/elementwise_functions/erf.hpp"
35-
#include "kernels/elementwise_functions/erfc.hpp"
35+
// #include "kernels/elementwise_functions/erfc.hpp"
3636

3737
// utils extension header
3838
#include "ext/common.hpp"

dpnp/backend/kernels/elementwise_functions/erf.hpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,44 @@
3131

3232
namespace dpnp::kernels::erf
3333
{
34-
template <typename argT, typename Tp>
35-
struct ErfFunctor
34+
template <typename OpT, typename ArgT, typename ResT>
35+
struct BaseFunctor
3636
{
37-
// is function constant for given argT
37+
// is function constant for given ArgT
3838
using is_constant = typename std::false_type;
3939
// constant value, if constant
40-
// constexpr Tp constant_value = Tp{};
40+
// constexpr ResT constant_value = ResT{};
4141
// is function defined for sycl::vec
4242
using supports_vec = typename std::false_type;
43-
// do both argT and Tp support subgroup store/load operation
43+
// do both ArgT and ResT support subgroup store/load operation
4444
using supports_sg_loadstore = typename std::true_type;
4545

46-
Tp operator()(const argT &x) const
46+
ResT operator()(const ArgT &x) const
4747
{
48-
if constexpr (std::is_same_v<argT, sycl::half> &&
49-
std::is_same_v<Tp, float>) {
48+
if constexpr (std::is_same_v<ArgT, sycl::half> &&
49+
std::is_same_v<ResT, float>) {
5050
// cast sycl::half to float for accuracy reasons
51-
return sycl::erf(float(x));
51+
return OpT::apply(float(x));
5252
}
5353
else {
54-
return sycl::erf(x);
54+
return OpT::apply(x);
5555
}
5656
}
5757
};
58+
59+
#define MACRO_DEFINE_FUNCTOR(__name__, __f_name__) \
60+
struct __f_name__##Op \
61+
{ \
62+
template <typename Tp> \
63+
static Tp apply(const Tp &x) \
64+
{ \
65+
return sycl::__name__(x); \
66+
} \
67+
}; \
68+
\
69+
template <typename ArgT, typename ResT> \
70+
using __f_name__##Functor = BaseFunctor<__f_name__##Op, ArgT, ResT>;
71+
72+
MACRO_DEFINE_FUNCTOR(erf, Erf);
73+
MACRO_DEFINE_FUNCTOR(erfc, Erfc);
5874
} // namespace dpnp::kernels::erf

dpnp/backend/kernels/elementwise_functions/erfc.hpp

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)