|
31 | 31 |
|
32 | 32 | namespace dpnp::kernels::erf |
33 | 33 | { |
34 | | -template <typename argT, typename Tp> |
35 | | -struct ErfFunctor |
| 34 | +template <typename OpT, typename ArgT, typename ResT> |
| 35 | +struct BaseFunctor |
36 | 36 | { |
37 | | - // is function constant for given argT |
| 37 | + // is function constant for given ArgT |
38 | 38 | using is_constant = typename std::false_type; |
39 | 39 | // constant value, if constant |
40 | | - // constexpr Tp constant_value = Tp{}; |
| 40 | + // constexpr ResT constant_value = ResT{}; |
41 | 41 | // is function defined for sycl::vec |
42 | 42 | using supports_vec = typename std::false_type; |
43 | | - // do both argT and Tp support subgroup store/load operation |
| 43 | + // do both ArgT and ResT support subgroup store/load operation |
44 | 44 | using supports_sg_loadstore = typename std::true_type; |
45 | 45 |
|
46 | | - Tp operator()(const argT &x) const |
| 46 | + ResT operator()(const ArgT &x) const |
47 | 47 | { |
48 | | - if constexpr (std::is_same_v<argT, sycl::half> && |
49 | | - std::is_same_v<Tp, float>) { |
| 48 | + if constexpr (std::is_same_v<ArgT, sycl::half> && |
| 49 | + std::is_same_v<ResT, float>) { |
50 | 50 | // cast sycl::half to float for accuracy reasons |
51 | | - return sycl::erf(float(x)); |
| 51 | + return OpT::apply(float(x)); |
52 | 52 | } |
53 | 53 | else { |
54 | | - return sycl::erf(x); |
| 54 | + return OpT::apply(x); |
55 | 55 | } |
56 | 56 | } |
57 | 57 | }; |
| 58 | + |
| 59 | +#define MACRO_DEFINE_FUNCTOR(__name__, __f_name__) \ |
| 60 | + struct __f_name__##Op \ |
| 61 | + { \ |
| 62 | + template <typename Tp> \ |
| 63 | + static Tp apply(const Tp &x) \ |
| 64 | + { \ |
| 65 | + return sycl::__name__(x); \ |
| 66 | + } \ |
| 67 | + }; \ |
| 68 | + \ |
| 69 | + template <typename ArgT, typename ResT> \ |
| 70 | + using __f_name__##Functor = BaseFunctor<__f_name__##Op, ArgT, ResT>; |
| 71 | + |
| 72 | +MACRO_DEFINE_FUNCTOR(erf, Erf); |
| 73 | +MACRO_DEFINE_FUNCTOR(erfc, Erfc); |
58 | 74 | } // namespace dpnp::kernels::erf |
0 commit comments