46
46
47
47
#include " integer_advanced_indexing.hpp"
48
48
49
- #define INDEXING_MODES 2
50
- #define WRAP_MODE 0
51
- #define CLIP_MODE 1
52
-
53
49
namespace dpctl
54
50
{
55
51
namespace tensor
@@ -62,11 +58,17 @@ namespace td_ns = dpctl::tensor::type_dispatch;
62
58
using dpctl::tensor::kernels::indexing::put_fn_ptr_t ;
63
59
using dpctl::tensor::kernels::indexing::take_fn_ptr_t ;
64
60
65
- static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66
- [td_ns::num_types];
61
+ static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62
+ [td_ns::num_types];
63
+
64
+ static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65
+ [td_ns::num_types];
66
+
67
+ static put_fn_ptr_t put_wrap_dispatch_table[INDEXING_MODES][td_ns::num_types]
68
+ [td_ns::num_types];
67
69
68
- static put_fn_ptr_t put_dispatch_table [INDEXING_MODES][td_ns::num_types]
69
- [td_ns::num_types];
70
+ static put_fn_ptr_t put_clip_dispatch_table [INDEXING_MODES][td_ns::num_types]
71
+ [td_ns::num_types];
70
72
71
73
namespace py = pybind11;
72
74
@@ -486,7 +488,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486
488
std::end (pack_deps));
487
489
all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
488
490
489
- auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
491
+ auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
492
+ : take_clip_dispatch_table[src_type_id][ind_type_id];
490
493
491
494
if (fn == nullptr ) {
492
495
sycl::event::wait (host_task_events);
@@ -755,7 +758,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755
758
std::end (pack_deps));
756
759
all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
757
760
758
- auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
761
+ auto fn = mode ? put_wrap_dispatch_table[src_type_id][ind_type_id]
762
+ : put_clip_dispatch_table[src_type_id][ind_type_id];
759
763
760
764
if (fn == nullptr ) {
761
765
sycl::event::wait (host_task_events);
@@ -790,20 +794,20 @@ void init_advanced_indexing_dispatch_tables(void)
790
794
using dpctl::tensor::kernels::indexing::TakeClipFactory;
791
795
DispatchTableBuilder<take_fn_ptr_t , TakeClipFactory, num_types>
792
796
dtb_takeclip;
793
- dtb_takeclip.populate_dispatch_table (take_dispatch_table[CLIP_MODE] );
797
+ dtb_takeclip.populate_dispatch_table (take_clip_dispatch_table );
794
798
795
799
using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796
800
DispatchTableBuilder<take_fn_ptr_t , TakeWrapFactory, num_types>
797
801
dtb_takewrap;
798
- dtb_takewrap.populate_dispatch_table (take_dispatch_table[WRAP_MODE] );
802
+ dtb_takewrap.populate_dispatch_table (take_wrap_dispatch_table );
799
803
800
804
using dpctl::tensor::kernels::indexing::PutClipFactory;
801
805
DispatchTableBuilder<put_fn_ptr_t , PutClipFactory, num_types> dtb_putclip;
802
- dtb_putclip.populate_dispatch_table (put_dispatch_table[CLIP_MODE] );
806
+ dtb_putclip.populate_dispatch_table (put_clip_dispatch_table );
803
807
804
808
using dpctl::tensor::kernels::indexing::PutWrapFactory;
805
809
DispatchTableBuilder<put_fn_ptr_t , PutWrapFactory, num_types> dtb_putwrap;
806
- dtb_putwrap.populate_dispatch_table (put_dispatch_table[WRAP_MODE] );
810
+ dtb_putwrap.populate_dispatch_table (put_wrap_dispatch_table );
807
811
}
808
812
809
813
} // namespace py_internal
0 commit comments