@@ -163,7 +163,7 @@ sycl::event add_contig_impl(sycl::queue exec_q,
163
163
py::ssize_t res_offset,
164
164
const std::vector<sycl::event> &depends = {})
165
165
{
166
- sycl::event add_ev = exec_q.submit ([&](sycl::handler &cgh) {
166
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
167
167
cgh.depends_on (depends);
168
168
169
169
size_t lws = 64 ;
@@ -188,7 +188,7 @@ sycl::event add_contig_impl(sycl::queue exec_q,
188
188
AddContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
189
189
arg1_tp, arg2_tp, res_tp, nelems));
190
190
});
191
- return add_ev ;
191
+ return comp_ev ;
192
192
}
193
193
194
194
template <typename fnT, typename T1, typename T2> struct AddContigFactory
@@ -249,7 +249,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
249
249
const std::vector<sycl::event> &depends,
250
250
const std::vector<sycl::event> &additional_depends)
251
251
{
252
- sycl::event abs_ev = exec_q.submit ([&](sycl::handler &cgh) {
252
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
253
253
cgh.depends_on (depends);
254
254
cgh.depends_on (additional_depends);
255
255
@@ -270,7 +270,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
270
270
{nelems}, AddStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
271
271
arg1_tp, arg2_tp, res_tp, indexer));
272
272
});
273
- return abs_ev ;
273
+ return comp_ev ;
274
274
}
275
275
276
276
template <typename fnT, typename T1, typename T2> struct AddStridedFactory
@@ -290,7 +290,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
290
290
};
291
291
292
292
template <typename argT1, typename argT2, typename resT>
293
- class add_matrix_vector_broadcast_sg_krn ;
293
+ class add_matrix_row_broadcast_sg_krn ;
294
294
295
295
typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t )(
296
296
sycl::queue,
@@ -305,6 +305,14 @@ typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)(
305
305
py::ssize_t ,
306
306
const std::vector<sycl::event> &);
307
307
308
+ template <typename argT1, typename argT2, typename resT>
309
+ using AddContigMatrixContigRowBroadcastingFunctor =
310
+ elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor<
311
+ argT1,
312
+ argT2,
313
+ resT,
314
+ AddFunctor<argT1, argT2, resT>>;
315
+
308
316
template <typename argT1, typename argT2, typename resT>
309
317
sycl::event add_contig_matrix_contig_row_broadcast_impl (
310
318
sycl::queue exec_q,
@@ -361,41 +369,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
361
369
size_t n_groups = (n_elems + lws - 1 ) / lws;
362
370
auto gwsRange = sycl::range<1 >(n_groups * lws);
363
371
364
- cgh.parallel_for <class add_matrix_vector_broadcast_sg_krn <argT1, argT2, resT>>(
372
+ cgh.parallel_for <
373
+ class add_matrix_row_broadcast_sg_krn <argT1, argT2, resT>>(
365
374
sycl::nd_range<1 >(gwsRange, lwsRange),
366
- [=](sycl::nd_item<1 > ndit)
367
- {
368
- auto sg = ndit.get_sub_group ();
369
- size_t gid = ndit.get_global_linear_id ();
370
-
371
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
372
- size_t base = gid - sg.get_local_id ()[0 ];
373
-
374
- if (base + sgSize < n_elems) {
375
- using in_ptrT1 =
376
- sycl::multi_ptr<const argT1,
377
- sycl::access::address_space::global_space>;
378
- using in_ptrT2 =
379
- sycl::multi_ptr<const argT2,
380
- sycl::access::address_space::global_space>;
381
- using res_ptrT =
382
- sycl::multi_ptr<resT,
383
- sycl::access::address_space::global_space>;
384
-
385
- const argT1 mat_el = sg.load (in_ptrT1 (&mat[base]));
386
- const argT2 vec_el = sg.load (in_ptrT2 (&padded_vec[base % n1]));
387
-
388
- resT res_el = mat_el + vec_el;
389
-
390
- sg.store (res_ptrT (&res[base]), res_el);
391
- }
392
- else {
393
- for (size_t k = base + sg.get_local_id ()[0 ]; k < n_elems;
394
- k += sgSize) {
395
- res[k] = mat[k] + padded_vec[k % n1];
396
- }
397
- }
398
- });
375
+ AddContigMatrixContigRowBroadcastingFunctor<argT1, argT2, resT>(
376
+ mat, padded_vec, res, n_elems, n1));
399
377
});
400
378
401
379
sycl::event tmp_cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
@@ -413,13 +391,12 @@ struct AddContigMatrixContigRowBroadcastFactory
413
391
{
414
392
fnT get ()
415
393
{
416
- if constexpr (std::is_same_v< typename AddOutputType<T1, T2>::value_type,
417
- void >) {
394
+ using resT = typename AddOutputType<T1, T2>::value_type;
395
+ if constexpr (std::is_same_v<resT, void >) {
418
396
fnT fn = nullptr ;
419
397
return fn;
420
398
}
421
399
else {
422
- using resT = typename AddOutputType<T1, T2>::value_type;
423
400
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
424
401
dpctl::tensor::type_utils::is_complex<T2>::value ||
425
402
dpctl::tensor::type_utils::is_complex<resT>::value)
@@ -474,13 +451,12 @@ struct AddContigRowContigMatrixBroadcastFactory
474
451
{
475
452
fnT get ()
476
453
{
477
- if constexpr (std::is_same_v< typename AddOutputType<T1, T2>::value_type,
478
- void >) {
454
+ using resT = typename AddOutputType<T1, T2>::value_type;
455
+ if constexpr (std::is_same_v<resT, void >) {
479
456
fnT fn = nullptr ;
480
457
return fn;
481
458
}
482
459
else {
483
- using resT = typename AddOutputType<T1, T2>::value_type;
484
460
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
485
461
dpctl::tensor::type_utils::is_complex<T2>::value ||
486
462
dpctl::tensor::type_utils::is_complex<resT>::value)
0 commit comments