Skip to content

Commit f2da70a

Browse files
committed
Reduce choose dispatching code duplication
1 parent 2897f07 commit f2da70a

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,14 @@ std::pair<sycl::event, sycl::event>
397397
return std::make_pair(arg_cleanup_ev, choose_generic_ev);
398398
}
399399

400-
template <typename fnT, typename IndT, typename T>
401-
struct ChooseWrapFactory
400+
template <typename fnT, typename IndT, typename T, typename Index>
401+
struct ChooseFactory
402402
{
403403
fnT get()
404404
{
405405
if constexpr (std::is_integral<IndT>::value &&
406406
!std::is_same<IndT, bool>::value) {
407-
using dpctl::tensor::indexing_utils::WrapIndex;
408-
fnT fn = kernels::choose_impl<WrapIndex<IndT>, IndT, T>;
407+
fnT fn = kernels::choose_impl<Index, IndT, T>;
409408
return fn;
410409
}
411410
else {
@@ -415,23 +414,14 @@ struct ChooseWrapFactory
415414
}
416415
};
417416

417+
using dpctl::tensor::indexing_utils::ClipIndex;
418+
using dpctl::tensor::indexing_utils::WrapIndex;
419+
418420
template <typename fnT, typename IndT, typename T>
419-
struct ChooseClipFactory
420-
{
421-
fnT get()
422-
{
423-
if constexpr (std::is_integral<IndT>::value &&
424-
!std::is_same<IndT, bool>::value) {
425-
using dpctl::tensor::indexing_utils::ClipIndex;
426-
fnT fn = kernels::choose_impl<ClipIndex<IndT>, IndT, T>;
427-
return fn;
428-
}
429-
else {
430-
fnT fn = nullptr;
431-
return fn;
432-
}
433-
}
434-
};
421+
using ChooseWrapFactory = ChooseFactory<fnT, IndT, T, WrapIndex<IndT>>;
422+
423+
template <typename fnT, typename IndT, typename T>
424+
using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
435425

436426
void init_choose_dispatch_tables(void)
437427
{

0 commit comments

Comments
 (0)