
Commit 6b56e99

Added BinaryContigMatrixContigRowBroadcastingFunctor and a RowMatrix variant
Added a templated callable to common.hpp that generates kernels for an arbitrary binary operator, and applied it to the addition code. Also implemented the true_divide operator, exported as _tensor_impl._divide.
1 parent b088a52 commit 6b56e99
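
The "templated callable" referred to in the message is the pair of class templates added to common.hpp in this commit (BinaryContigMatrixContigRowBroadcastingFunctor and BinaryContigRowContigMatrixBroadcastingFunctor); add.hpp now instantiates the former through an alias template instead of an inline lambda. As a hedged sketch of the reuse the message describes, and not part of this diff, a true_divide implementation could define an analogous alias; the name TrueDivideFunctor below is assumed, mirroring AddFunctor:

// Hypothetical sketch (not in this commit): reusing the generic broadcasting
// functor for another binary operator. TrueDivideFunctor is an assumed name
// patterned after AddFunctor in add.hpp.
template <typename argT1, typename argT2, typename resT>
using TrueDivideContigMatrixContigRowBroadcastingFunctor =
    elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor<
        argT1,
        argT2,
        resT,
        TrueDivideFunctor<argT1, argT2, resT>>;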

File tree: 4 files changed (+817 −52 lines)


dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 21 additions & 45 deletions
@@ -163,7 +163,7 @@ sycl::event add_contig_impl(sycl::queue exec_q,
                             py::ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
 {
-    sycl::event add_ev = exec_q.submit([&](sycl::handler &cgh) {
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

         size_t lws = 64;
@@ -188,7 +188,7 @@ sycl::event add_contig_impl(sycl::queue exec_q,
             AddContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
                 arg1_tp, arg2_tp, res_tp, nelems));
     });
-    return add_ev;
+    return comp_ev;
 }

 template <typename fnT, typename T1, typename T2> struct AddContigFactory
@@ -249,7 +249,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
                              const std::vector<sycl::event> &depends,
                              const std::vector<sycl::event> &additional_depends)
 {
-    sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) {
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
         cgh.depends_on(additional_depends);

@@ -270,7 +270,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
             {nelems}, AddStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
                           arg1_tp, arg2_tp, res_tp, indexer));
     });
-    return abs_ev;
+    return comp_ev;
 }

 template <typename fnT, typename T1, typename T2> struct AddStridedFactory
@@ -290,7 +290,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
 };

 template <typename argT1, typename argT2, typename resT>
-class add_matrix_vector_broadcast_sg_krn;
+class add_matrix_row_broadcast_sg_krn;

 typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)(
     sycl::queue,
@@ -305,6 +305,14 @@ typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)(
     py::ssize_t,
     const std::vector<sycl::event> &);

+template <typename argT1, typename argT2, typename resT>
+using AddContigMatrixContigRowBroadcastingFunctor =
+    elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor<
+        argT1,
+        argT2,
+        resT,
+        AddFunctor<argT1, argT2, resT>>;
+
 template <typename argT1, typename argT2, typename resT>
 sycl::event add_contig_matrix_contig_row_broadcast_impl(
     sycl::queue exec_q,
@@ -361,41 +369,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
         size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);

-        cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
+        cgh.parallel_for<
+            class add_matrix_row_broadcast_sg_krn<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            [=](sycl::nd_item<1> ndit)
-            {
-                auto sg = ndit.get_sub_group();
-                size_t gid = ndit.get_global_linear_id();
-
-                std::uint8_t sgSize = sg.get_local_range()[0];
-                size_t base = gid - sg.get_local_id()[0];
-
-                if (base + sgSize < n_elems) {
-                    using in_ptrT1 =
-                        sycl::multi_ptr<const argT1,
-                                        sycl::access::address_space::global_space>;
-                    using in_ptrT2 =
-                        sycl::multi_ptr<const argT2,
-                                        sycl::access::address_space::global_space>;
-                    using res_ptrT =
-                        sycl::multi_ptr<resT,
-                                        sycl::access::address_space::global_space>;
-
-                    const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
-                    const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
-
-                    resT res_el = mat_el + vec_el;
-
-                    sg.store(res_ptrT(&res[base]), res_el);
-                }
-                else {
-                    for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
-                         k += sgSize) {
-                        res[k] = mat[k] + padded_vec[k % n1];
-                    }
-                }
-            });
+            AddContigMatrixContigRowBroadcastingFunctor<argT1, argT2, resT>(
+                mat, padded_vec, res, n_elems, n1));
     });

     sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
@@ -413,13 +391,12 @@ struct AddContigMatrixContigRowBroadcastFactory
 {
     fnT get()
     {
-        if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
-                                     void>) {
+        using resT = typename AddOutputType<T1, T2>::value_type;
+        if constexpr (std::is_same_v<resT, void>) {
             fnT fn = nullptr;
             return fn;
         }
         else {
-            using resT = typename AddOutputType<T1, T2>::value_type;
             if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
                           dpctl::tensor::type_utils::is_complex<T2>::value ||
                           dpctl::tensor::type_utils::is_complex<resT>::value)
@@ -474,13 +451,12 @@ struct AddContigRowContigMatrixBroadcastFactory
 {
     fnT get()
     {
-        if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
-                                     void>) {
+        using resT = typename AddOutputType<T1, T2>::value_type;
+        if constexpr (std::is_same_v<resT, void>) {
             fnT fn = nullptr;
             return fn;
         }
         else {
-            using resT = typename AddOutputType<T1, T2>::value_type;
             if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
                           dpctl::tensor::type_utils::is_complex<T2>::value ||
                           dpctl::tensor::type_utils::is_complex<resT>::value)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 124 additions & 0 deletions
@@ -420,6 +420,130 @@ struct BinaryStridedFunctor
     }
 };

+template <typename argT1,
+          typename argT2,
+          typename resT,
+          typename BinaryOperatorT>
+struct BinaryContigMatrixContigRowBroadcastingFunctor
+{
+private:
+    const argT1 *mat;
+    const argT2 *padded_vec;
+    resT *res;
+    size_t n_elems;
+    size_t n1;
+
+public:
+    BinaryContigMatrixContigRowBroadcastingFunctor(const argT1 *mat_tp,
+                                                   const argT2 *row_tp,
+                                                   resT *res_tp,
+                                                   size_t n_elems_in_mat,
+                                                   size_t n_elems_in_row)
+        : mat(mat_tp), padded_vec(row_tp), res(res_tp), n_elems(n_elems_in_mat),
+          n1(n_elems_in_row)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        BinaryOperatorT op{};
+        static_assert(BinaryOperatorT::supports_sg_loadstore::value);
+
+        auto sg = ndit.get_sub_group();
+        size_t gid = ndit.get_global_linear_id();
+
+        std::uint8_t sgSize = sg.get_local_range()[0];
+        size_t base = gid - sg.get_local_id()[0];
+
+        if (base + sgSize < n_elems) {
+            using in_ptrT1 =
+                sycl::multi_ptr<const argT1,
+                                sycl::access::address_space::global_space>;
+            using in_ptrT2 =
+                sycl::multi_ptr<const argT2,
+                                sycl::access::address_space::global_space>;
+            using res_ptrT =
+                sycl::multi_ptr<resT,
+                                sycl::access::address_space::global_space>;
+
+            const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
+            const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
+
+            resT res_el = op(mat_el, vec_el);
+
+            sg.store(res_ptrT(&res[base]), res_el);
+        }
+        else {
+            for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
+                 k += sgSize) {
+                res[k] = op(mat[k], padded_vec[k % n1]);
+            }
+        }
+    }
+};
+
+template <typename argT1,
+          typename argT2,
+          typename resT,
+          typename BinaryOperatorT>
+struct BinaryContigRowContigMatrixBroadcastingFunctor
+{
+private:
+    const argT1 *padded_vec;
+    const argT2 *mat;
+    resT *res;
+    size_t n_elems;
+    size_t n1;
+
+public:
+    BinaryContigRowContigMatrixBroadcastingFunctor(const argT1 *row_tp,
+                                                   const argT2 *mat_tp,
+                                                   resT *res_tp,
+                                                   size_t n_elems_in_mat,
+                                                   size_t n_elems_in_row)
+        : padded_vec(row_tp), mat(mat_tp), res(res_tp), n_elems(n_elems_in_mat),
+          n1(n_elems_in_row)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        BinaryOperatorT op{};
+        static_assert(BinaryOperatorT::supports_sg_loadstore::value);
+
+        auto sg = ndit.get_sub_group();
+        size_t gid = ndit.get_global_linear_id();
+
+        std::uint8_t sgSize = sg.get_local_range()[0];
+        size_t base = gid - sg.get_local_id()[0];
+
+        if (base + sgSize < n_elems) {
+            using in_ptrT1 =
+                sycl::multi_ptr<const argT1,
+                                sycl::access::address_space::global_space>;
+            using in_ptrT2 =
+                sycl::multi_ptr<const argT2,
+                                sycl::access::address_space::global_space>;
+            using res_ptrT =
+                sycl::multi_ptr<resT,
+                                sycl::access::address_space::global_space>;
+
+            const argT2 mat_el = sg.load(in_ptrT2(&mat[base]));
+            const argT1 vec_el = sg.load(in_ptrT1(&padded_vec[base % n1]));
+
+            resT res_el = op(vec_el, mat_el);
+
+            sg.store(res_ptrT(&res[base]), res_el);
+        }
+        else {
+            for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
+                 k += sgSize) {
+                res[k] = op(padded_vec[k % n1], mat[k]);
+            }
+        }
+    }
+};
+
 } // namespace elementwise_common
 } // namespace kernels
 } // namespace tensor
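
Both new functors default-construct their operator and static_assert on BinaryOperatorT::supports_sg_loadstore::value, so only operator types that advertise sub-group load/store support can be plugged in. Below is a minimal sketch of an operator type satisfying that contract; it is illustrative only (the actual AddFunctor used by add.hpp is not part of the diffs above) and assumes <type_traits> is available for std::true_type.

// Illustrative operator type meeting the contract the broadcasting functors
// check: default-constructible, callable on (argT1, argT2) -> resT, and
// exposing a supports_sg_loadstore trait that evaluates to true.
template <typename argT1, typename argT2, typename resT> struct MyOpFunctor
{
    using supports_sg_loadstore = std::true_type; // checked by static_assert

    resT operator()(const argT1 &a, const argT2 &b) const
    {
        return a + b; // any element-wise binary operation
    }
};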
