Skip to content

Commit e8e66ed

Browse files
committed
Change integer indexing mode dispatching
1 parent 448a2f9 commit e8e66ed

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@
4646

4747
#include "integer_advanced_indexing.hpp"
4848

49-
#define INDEXING_MODES 2
50-
#define WRAP_MODE 0
51-
#define CLIP_MODE 1
52-
5349
namespace dpctl
5450
{
5551
namespace tensor
@@ -62,11 +58,17 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6258
using dpctl::tensor::kernels::indexing::put_fn_ptr_t;
6359
using dpctl::tensor::kernels::indexing::take_fn_ptr_t;
6460

65-
static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66-
[td_ns::num_types];
61+
static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62+
[td_ns::num_types];
63+
64+
static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65+
[td_ns::num_types];
66+
67+
static put_fn_ptr_t put_wrap_dispatch_table[INDEXING_MODES][td_ns::num_types]
68+
[td_ns::num_types];
6769

68-
static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types]
69-
[td_ns::num_types];
70+
static put_fn_ptr_t put_clip_dispatch_table[INDEXING_MODES][td_ns::num_types]
71+
[td_ns::num_types];
7072

7173
namespace py = pybind11;
7274

@@ -486,7 +488,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486488
std::end(pack_deps));
487489
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
488490

489-
auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
491+
auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
492+
: take_clip_dispatch_table[src_type_id][ind_type_id];
490493

491494
if (fn == nullptr) {
492495
sycl::event::wait(host_task_events);
@@ -755,7 +758,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755758
std::end(pack_deps));
756759
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
757760

758-
auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
761+
auto fn = mode ? put_wrap_dispatch_table[src_type_id][ind_type_id]
762+
: put_clip_dispatch_table[src_type_id][ind_type_id];
759763

760764
if (fn == nullptr) {
761765
sycl::event::wait(host_task_events);
@@ -790,20 +794,20 @@ void init_advanced_indexing_dispatch_tables(void)
790794
using dpctl::tensor::kernels::indexing::TakeClipFactory;
791795
DispatchTableBuilder<take_fn_ptr_t, TakeClipFactory, num_types>
792796
dtb_takeclip;
793-
dtb_takeclip.populate_dispatch_table(take_dispatch_table[CLIP_MODE]);
797+
dtb_takeclip.populate_dispatch_table(take_clip_dispatch_table);
794798

795799
using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796800
DispatchTableBuilder<take_fn_ptr_t, TakeWrapFactory, num_types>
797801
dtb_takewrap;
798-
dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]);
802+
dtb_takewrap.populate_dispatch_table(take_wrap_dispatch_table);
799803

800804
using dpctl::tensor::kernels::indexing::PutClipFactory;
801805
DispatchTableBuilder<put_fn_ptr_t, PutClipFactory, num_types> dtb_putclip;
802-
dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]);
806+
dtb_putclip.populate_dispatch_table(put_clip_dispatch_table);
803807

804808
using dpctl::tensor::kernels::indexing::PutWrapFactory;
805809
DispatchTableBuilder<put_fn_ptr_t, PutWrapFactory, num_types> dtb_putwrap;
806-
dtb_putwrap.populate_dispatch_table(put_dispatch_table[WRAP_MODE]);
810+
dtb_putwrap.populate_dispatch_table(put_wrap_dispatch_table);
807811
}
808812

809813
} // namespace py_internal

0 commit comments

Comments
 (0)