Skip to content

Commit 8db6cfc

Browse files
committed
Reduce code duplication by adding a macro to populate unary ufuncs with two output arrays
1 parent 74c40be commit 8db6cfc

File tree

3 files changed

+113
-232
lines changed

3 files changed

+113
-232
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/frexp.cpp

Lines changed: 5 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include "frexp.hpp"
3939
#include "kernels/elementwise_functions/frexp.hpp"
40+
#include "populate.hpp"
4041

4142
// include a local copy of elementwise common header from dpctl tensor:
4243
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
@@ -65,10 +66,9 @@ namespace td_int_ns = py_int::type_dispatch;
6566
namespace td_ns = dpctl::tensor::type_dispatch;
6667

6768
using dpnp::kernels::frexp::FrexpFunctor;
68-
using ext::common::init_dispatch_vector;
6969

7070
template <typename T>
71-
struct FrexpOutputType
71+
struct OutputType
7272
{
7373
using table_type = std::disjunction< // disjunction is C++17
7474
// feature, supported by DPC++
@@ -81,15 +81,13 @@ struct FrexpOutputType
8181
using value_type2 = typename table_type::result_type2;
8282
};
8383

84-
// contiguous implementation
85-
8684
template <typename argTy,
8785
typename resTy1 = argTy,
8886
typename resTy2 = argTy,
8987
std::uint8_t vec_sz = 4u,
9088
std::uint8_t n_vecs = 2u,
9189
bool enable_sg_loadstore = true>
92-
using FrexpContigFunctor =
90+
using ContigFunctor =
9391
ew_cmn_ns::UnaryTwoOutputsContigFunctor<argTy,
9492
resTy1,
9593
resTy2,
@@ -98,115 +96,14 @@ using FrexpContigFunctor =
9896
n_vecs,
9997
enable_sg_loadstore>;
10098

101-
// strided implementation
102-
10399
template <typename argTy, typename resTy1, typename resTy2, typename IndexerT>
104-
using FrexpStridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
100+
using StridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
105101
argTy,
106102
resTy1,
107103
resTy2,
108104
IndexerT,
109105
FrexpFunctor<argTy, resTy1, resTy2>>;
110106

111-
template <typename T1,
112-
typename T2,
113-
typename T3,
114-
unsigned int vec_sz,
115-
unsigned int n_vecs>
116-
class frexp_contig_kernel;
117-
118-
template <typename argTy>
119-
sycl::event frexp_contig_impl(sycl::queue &exec_q,
120-
size_t nelems,
121-
const char *arg_p,
122-
char *res1_p,
123-
char *res2_p,
124-
const std::vector<sycl::event> &depends = {})
125-
{
126-
return ew_cmn_ns::unary_two_outputs_contig_impl<
127-
argTy, FrexpOutputType, FrexpContigFunctor, frexp_contig_kernel>(
128-
exec_q, nelems, arg_p, res1_p, res2_p, depends);
129-
}
130-
131-
template <typename fnT, typename T>
132-
struct FrexpContigFactory
133-
{
134-
fnT get()
135-
{
136-
if constexpr (std::is_same_v<typename FrexpOutputType<T>::value_type1,
137-
void> ||
138-
std::is_same_v<typename FrexpOutputType<T>::value_type2,
139-
void>)
140-
{
141-
fnT fn = nullptr;
142-
return fn;
143-
}
144-
else {
145-
fnT fn = frexp_contig_impl<T>;
146-
return fn;
147-
}
148-
}
149-
};
150-
151-
template <typename T1, typename T2, typename T3, typename T4>
152-
class frexp_strided_kernel;
153-
154-
template <typename argTy>
155-
sycl::event
156-
frexp_strided_impl(sycl::queue &exec_q,
157-
size_t nelems,
158-
int nd,
159-
const ssize_t *shape_and_strides,
160-
const char *arg_p,
161-
ssize_t arg_offset,
162-
char *res1_p,
163-
ssize_t res1_offset,
164-
char *res2_p,
165-
ssize_t res2_offset,
166-
const std::vector<sycl::event> &depends,
167-
const std::vector<sycl::event> &additional_depends)
168-
{
169-
return ew_cmn_ns::unary_two_outputs_strided_impl<
170-
argTy, FrexpOutputType, FrexpStridedFunctor, frexp_strided_kernel>(
171-
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p,
172-
res1_offset, res2_p, res2_offset, depends, additional_depends);
173-
}
174-
175-
template <typename fnT, typename T>
176-
struct FrexpStridedFactory
177-
{
178-
fnT get()
179-
{
180-
if constexpr (std::is_same_v<typename FrexpOutputType<T>::value_type1,
181-
void> ||
182-
std::is_same_v<typename FrexpOutputType<T>::value_type2,
183-
void>)
184-
{
185-
fnT fn = nullptr;
186-
return fn;
187-
}
188-
else {
189-
fnT fn = frexp_strided_impl<T>;
190-
return fn;
191-
}
192-
}
193-
};
194-
195-
template <typename fnT, typename T>
196-
struct FrexpTypeMapFactory
197-
{
198-
/*! @brief get typeid for output type of sycl::frexp(T x) */
199-
std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,
200-
std::pair<int, int>>
201-
get()
202-
{
203-
using rT1 = typename FrexpOutputType<T>::value_type1;
204-
using rT2 = typename FrexpOutputType<T>::value_type2;
205-
return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),
206-
td_ns::GetTypeid<rT2>{}.get());
207-
}
208-
};
209-
210107
using ew_cmn_ns::unary_two_outputs_contig_impl_fn_ptr_t;
211108
using ew_cmn_ns::unary_two_outputs_strided_impl_fn_ptr_t;
212109

@@ -216,15 +113,7 @@ static std::pair<int, int> frexp_output_typeid_vector[td_ns::num_types];
216113
static unary_two_outputs_strided_impl_fn_ptr_t
217114
frexp_strided_dispatch_vector[td_ns::num_types];
218115

219-
void populate_frexp_dispatch_vectors(void)
220-
{
221-
init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,
222-
FrexpContigFactory>(frexp_contig_dispatch_vector);
223-
init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t,
224-
FrexpStridedFactory>(frexp_strided_dispatch_vector);
225-
init_dispatch_vector<std::pair<int, int>, FrexpTypeMapFactory>(
226-
frexp_output_typeid_vector);
227-
};
116+
MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(frexp);
228117
} // namespace impl
229118

230119
void init_frexp(py::module_ m)

dpnp/backend/extensions/ufunc/elementwise_functions/modf.cpp

Lines changed: 5 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include "kernels/elementwise_functions/modf.hpp"
3939
#include "modf.hpp"
40+
#include "populate.hpp"
4041

4142
// include a local copy of elementwise common header from dpctl tensor:
4243
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
@@ -65,10 +66,9 @@ namespace td_int_ns = py_int::type_dispatch;
6566
namespace td_ns = dpctl::tensor::type_dispatch;
6667

6768
using dpnp::kernels::modf::ModfFunctor;
68-
using ext::common::init_dispatch_vector;
6969

7070
template <typename T>
71-
struct ModfOutputType
71+
struct OutputType
7272
{
7373
using table_type = std::disjunction< // disjunction is C++17
7474
// feature, supported by DPC++
@@ -81,15 +81,13 @@ struct ModfOutputType
8181
using value_type2 = typename table_type::result_type2;
8282
};
8383

84-
// contiguous implementation
85-
8684
template <typename argTy,
8785
typename resTy1 = argTy,
8886
typename resTy2 = argTy,
8987
std::uint8_t vec_sz = 4u,
9088
std::uint8_t n_vecs = 2u,
9189
bool enable_sg_loadstore = true>
92-
using ModfContigFunctor =
90+
using ContigFunctor =
9391
ew_cmn_ns::UnaryTwoOutputsContigFunctor<argTy,
9492
resTy1,
9593
resTy2,
@@ -98,115 +96,14 @@ using ModfContigFunctor =
9896
n_vecs,
9997
enable_sg_loadstore>;
10098

101-
// strided implementation
102-
10399
template <typename argTy, typename resTy1, typename resTy2, typename IndexerT>
104-
using ModfStridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
100+
using StridedFunctor = ew_cmn_ns::UnaryTwoOutputsStridedFunctor<
105101
argTy,
106102
resTy1,
107103
resTy2,
108104
IndexerT,
109105
ModfFunctor<argTy, resTy1, resTy2>>;
110106

111-
template <typename T1,
112-
typename T2,
113-
typename T3,
114-
unsigned int vec_sz,
115-
unsigned int n_vecs>
116-
class modf_contig_kernel;
117-
118-
template <typename argTy>
119-
sycl::event modf_contig_impl(sycl::queue &exec_q,
120-
size_t nelems,
121-
const char *arg_p,
122-
char *res1_p,
123-
char *res2_p,
124-
const std::vector<sycl::event> &depends = {})
125-
{
126-
return ew_cmn_ns::unary_two_outputs_contig_impl<
127-
argTy, ModfOutputType, ModfContigFunctor, modf_contig_kernel>(
128-
exec_q, nelems, arg_p, res1_p, res2_p, depends);
129-
}
130-
131-
template <typename fnT, typename T>
132-
struct ModfContigFactory
133-
{
134-
fnT get()
135-
{
136-
if constexpr (std::is_same_v<typename ModfOutputType<T>::value_type1,
137-
void> ||
138-
std::is_same_v<typename ModfOutputType<T>::value_type2,
139-
void>)
140-
{
141-
fnT fn = nullptr;
142-
return fn;
143-
}
144-
else {
145-
fnT fn = modf_contig_impl<T>;
146-
return fn;
147-
}
148-
}
149-
};
150-
151-
template <typename T1, typename T2, typename T3, typename T4>
152-
class modf_strided_kernel;
153-
154-
template <typename argTy>
155-
sycl::event
156-
modf_strided_impl(sycl::queue &exec_q,
157-
size_t nelems,
158-
int nd,
159-
const ssize_t *shape_and_strides,
160-
const char *arg_p,
161-
ssize_t arg_offset,
162-
char *res1_p,
163-
ssize_t res1_offset,
164-
char *res2_p,
165-
ssize_t res2_offset,
166-
const std::vector<sycl::event> &depends,
167-
const std::vector<sycl::event> &additional_depends)
168-
{
169-
return ew_cmn_ns::unary_two_outputs_strided_impl<
170-
argTy, ModfOutputType, ModfStridedFunctor, modf_strided_kernel>(
171-
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p,
172-
res1_offset, res2_p, res2_offset, depends, additional_depends);
173-
}
174-
175-
template <typename fnT, typename T>
176-
struct ModfStridedFactory
177-
{
178-
fnT get()
179-
{
180-
if constexpr (std::is_same_v<typename ModfOutputType<T>::value_type1,
181-
void> ||
182-
std::is_same_v<typename ModfOutputType<T>::value_type2,
183-
void>)
184-
{
185-
fnT fn = nullptr;
186-
return fn;
187-
}
188-
else {
189-
fnT fn = modf_strided_impl<T>;
190-
return fn;
191-
}
192-
}
193-
};
194-
195-
template <typename fnT, typename T>
196-
struct ModfTypeMapFactory
197-
{
198-
/*! @brief get typeid for output type of sycl::modf(T x) */
199-
std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,
200-
std::pair<int, int>>
201-
get()
202-
{
203-
using rT1 = typename ModfOutputType<T>::value_type1;
204-
using rT2 = typename ModfOutputType<T>::value_type2;
205-
return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),
206-
td_ns::GetTypeid<rT2>{}.get());
207-
}
208-
};
209-
210107
using ew_cmn_ns::unary_two_outputs_contig_impl_fn_ptr_t;
211108
using ew_cmn_ns::unary_two_outputs_strided_impl_fn_ptr_t;
212109

@@ -216,15 +113,7 @@ static std::pair<int, int> modf_output_typeid_vector[td_ns::num_types];
216113
static unary_two_outputs_strided_impl_fn_ptr_t
217114
modf_strided_dispatch_vector[td_ns::num_types];
218115

219-
void populate_modf_dispatch_vectors(void)
220-
{
221-
init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,
222-
ModfContigFactory>(modf_contig_dispatch_vector);
223-
init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t,
224-
ModfStridedFactory>(modf_strided_dispatch_vector);
225-
init_dispatch_vector<std::pair<int, int>, ModfTypeMapFactory>(
226-
modf_output_typeid_vector);
227-
};
116+
MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(modf);
228117
} // namespace impl
229118

230119
void init_modf(py::module_ m)

0 commit comments

Comments
 (0)