@@ -167,32 +167,10 @@ sycl::event not_equal_contig_impl(sycl::queue exec_q,
167
167
py::ssize_t res_offset,
168
168
const std::vector<sycl::event> &depends = {})
169
169
{
170
- sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
171
- cgh.depends_on (depends);
172
-
173
- size_t lws = 64 ;
174
- constexpr unsigned int vec_sz = 4 ;
175
- constexpr unsigned int n_vecs = 2 ;
176
- const size_t n_groups =
177
- ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
178
- const auto gws_range = sycl::range<1 >(n_groups * lws);
179
- const auto lws_range = sycl::range<1 >(lws);
180
-
181
- using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
182
-
183
- const argTy1 *arg1_tp =
184
- reinterpret_cast <const argTy1 *>(arg1_p) + arg1_offset;
185
- const argTy2 *arg2_tp =
186
- reinterpret_cast <const argTy2 *>(arg2_p) + arg2_offset;
187
- resTy *res_tp = reinterpret_cast <resTy *>(res_p) + res_offset;
188
-
189
- cgh.parallel_for <
190
- not_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
191
- sycl::nd_range<1 >(gws_range, lws_range),
192
- NotEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
193
- arg1_tp, arg2_tp, res_tp, nelems));
194
- });
195
- return comp_ev;
170
+ return elementwise_common::binary_contig_impl<
171
+ argTy1, argTy2, NotEqualOutputType, NotEqualContigFunctor,
172
+ not_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
173
+ arg2_offset, res_p, res_offset, depends);
196
174
}
197
175
198
176
template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
@@ -215,7 +193,7 @@ template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
215
193
216
194
template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
217
195
{
218
- /* ! @brief get typeid for output type of operator()= =(x, y), always bool */
196
+ /* ! @brief get typeid for output type of operator()! =(x, y), always bool */
219
197
std::enable_if_t <std::is_same<fnT, int >::value, int > get ()
220
198
{
221
199
using rT = typename NotEqualOutputType<T1, T2>::value_type;
@@ -241,28 +219,11 @@ not_equal_strided_impl(sycl::queue exec_q,
241
219
const std::vector<sycl::event> &depends,
242
220
const std::vector<sycl::event> &additional_depends)
243
221
{
244
- sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
245
- cgh.depends_on (depends);
246
- cgh.depends_on (additional_depends);
247
-
248
- using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
249
-
250
- using IndexerT =
251
- typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
252
-
253
- IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
254
- shape_and_strides};
255
-
256
- const argTy1 *arg1_tp = reinterpret_cast <const argTy1 *>(arg1_p);
257
- const argTy2 *arg2_tp = reinterpret_cast <const argTy2 *>(arg2_p);
258
- resTy *res_tp = reinterpret_cast <resTy *>(res_p);
259
-
260
- cgh.parallel_for <
261
- not_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
262
- {nelems}, NotEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
263
- arg1_tp, arg2_tp, res_tp, indexer));
264
- });
265
- return comp_ev;
222
+ return elementwise_common::binary_strided_impl<
223
+ argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
224
+ not_equal_strided_strided_kernel>(
225
+ exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
226
+ arg2_offset, res_p, res_offset, depends, additional_depends);
266
227
}
267
228
268
229
template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory
0 commit comments