@@ -397,15 +397,14 @@ std::pair<sycl::event, sycl::event>
397397 return std::make_pair (arg_cleanup_ev, choose_generic_ev);
398398}
399399
400- template <typename fnT, typename IndT, typename T>
401- struct ChooseWrapFactory
400+ template <typename fnT, typename IndT, typename T, typename Index >
401+ struct ChooseFactory
402402{
403403 fnT get ()
404404 {
405405 if constexpr (std::is_integral<IndT>::value &&
406406 !std::is_same<IndT, bool >::value) {
407- using dpctl::tensor::indexing_utils::WrapIndex;
408- fnT fn = kernels::choose_impl<WrapIndex<IndT>, IndT, T>;
407+ fnT fn = kernels::choose_impl<Index, IndT, T>;
409408 return fn;
410409 }
411410 else {
@@ -415,23 +414,14 @@ struct ChooseWrapFactory
415414 }
416415};
417416
417+ using dpctl::tensor::indexing_utils::ClipIndex;
418+ using dpctl::tensor::indexing_utils::WrapIndex;
419+
418420template <typename fnT, typename IndT, typename T>
419- struct ChooseClipFactory
420- {
421- fnT get ()
422- {
423- if constexpr (std::is_integral<IndT>::value &&
424- !std::is_same<IndT, bool >::value) {
425- using dpctl::tensor::indexing_utils::ClipIndex;
426- fnT fn = kernels::choose_impl<ClipIndex<IndT>, IndT, T>;
427- return fn;
428- }
429- else {
430- fnT fn = nullptr ;
431- return fn;
432- }
433- }
434- };
421+ using ChooseWrapFactory = ChooseFactory<fnT, IndT, T, WrapIndex<IndT>>;
422+
423+ template <typename fnT, typename IndT, typename T>
424+ using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
435425
436426void init_choose_dispatch_tables (void )
437427{
0 commit comments