Skip to content

Commit 4a5a84a

Browse files
Use binary function template for not_equal
1 parent 644805f commit 4a5a84a

File tree

1 file changed

+10
-49
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+10
-49
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -167,32 +167,10 @@ sycl::event not_equal_contig_impl(sycl::queue exec_q,
167167
py::ssize_t res_offset,
168168
const std::vector<sycl::event> &depends = {})
169169
{
170-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
171-
cgh.depends_on(depends);
172-
173-
size_t lws = 64;
174-
constexpr unsigned int vec_sz = 4;
175-
constexpr unsigned int n_vecs = 2;
176-
const size_t n_groups =
177-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
178-
const auto gws_range = sycl::range<1>(n_groups * lws);
179-
const auto lws_range = sycl::range<1>(lws);
180-
181-
using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
182-
183-
const argTy1 *arg1_tp =
184-
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
185-
const argTy2 *arg2_tp =
186-
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
187-
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
188-
189-
cgh.parallel_for<
190-
not_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
191-
sycl::nd_range<1>(gws_range, lws_range),
192-
NotEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
193-
arg1_tp, arg2_tp, res_tp, nelems));
194-
});
195-
return comp_ev;
170+
return elementwise_common::binary_contig_impl<
171+
argTy1, argTy2, NotEqualOutputType, NotEqualContigFunctor,
172+
not_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
173+
arg2_offset, res_p, res_offset, depends);
196174
}
197175

198176
template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
@@ -215,7 +193,7 @@ template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
215193

216194
template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
217195
{
218-
/*! @brief get typeid for output type of operator()==(x, y), always bool */
196+
/*! @brief get typeid for output type of operator!=(x, y), always bool */
219197
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
220198
{
221199
using rT = typename NotEqualOutputType<T1, T2>::value_type;
@@ -241,28 +219,11 @@ not_equal_strided_impl(sycl::queue exec_q,
241219
const std::vector<sycl::event> &depends,
242220
const std::vector<sycl::event> &additional_depends)
243221
{
244-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
245-
cgh.depends_on(depends);
246-
cgh.depends_on(additional_depends);
247-
248-
using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
249-
250-
using IndexerT =
251-
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
252-
253-
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
254-
shape_and_strides};
255-
256-
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
257-
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
258-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
259-
260-
cgh.parallel_for<
261-
not_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
262-
{nelems}, NotEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
263-
arg1_tp, arg2_tp, res_tp, indexer));
264-
});
265-
return comp_ev;
222+
return elementwise_common::binary_strided_impl<
223+
argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
224+
not_equal_strided_strided_kernel>(
225+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
226+
arg2_offset, res_p, res_offset, depends, additional_depends);
266227
}
267228

268229
template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory

0 commit comments

Comments
 (0)