Skip to content

Commit fc46303

Browse files
Renamed advance_indexing.*pp into integer_advanced_indexing.*pp
Streamlined call operator implementation for projection classes. Added missing includes.
1 parent 0cf7ba4 commit fc46303

File tree

5 files changed

+22
-18
lines changed

5 files changed

+22
-18
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pybind11_add_module(${python_module_name} MODULE
3131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
34-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/advanced_indexing.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
3535
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
3636
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp

dpctl/tensor/libtensor/include/kernels/advanced_indexing.hpp renamed to dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "utils/strided_iters.hpp"
2727
#include "utils/type_utils.hpp"
2828
#include <CL/sycl.hpp>
29+
#include <algorithm>
2930
#include <complex>
3031
#include <cstdint>
3132
#include <pybind11/pybind11.h>
@@ -52,10 +53,9 @@ class ClipIndex
5253

5354
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
5455
{
55-
max_item = (max_item > 0) ? max_item : 1;
56-
ind = (ind < 0) ? (ind <= -max_item) ? (0) : (ind + max_item)
57-
: (ind >= max_item) ? (max_item - 1)
58-
: ind;
56+
max_item = std::max<py::ssize_t>(max_item, 1);
57+
ind = std::clamp<py::ssize_t>(ind, -max_item, max_item - 1);
58+
ind = (ind < 0) ? ind + max_item : ind;
5959
return;
6060
}
6161
};
@@ -67,10 +67,8 @@ class WrapIndex
6767

6868
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
6969
{
70-
max_item = (max_item > 0) ? max_item : 1;
71-
ind = (ind < 0) ? max_item - (-ind % max_item)
72-
: (ind >= max_item) ? ind % max_item
73-
: ind;
70+
max_item = std::max<py::ssize_t>(max_item, 1);
71+
ind = ind % max_item;
7472
return;
7573
}
7674
};
@@ -136,9 +134,9 @@ template <typename ProjectorT, typename T, typename indT> class TakeFunctor
136134
dst_orthog_idx);
137135

138136
ProjectorT proj{};
139-
py::ssize_t ind_arr_idx(0);
140137
CIndexer_vector<py::ssize_t> ind_indxr(ind_nd_);
141138
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
139+
py::ssize_t ind_arr_idx(0);
142140
ind_indxr.get_displacement<const py::ssize_t *>(
143141
static_cast<py::ssize_t>(i_along), ind_shape_and_strides_,
144142
ind_shape_and_strides_ + ((axis_idx + 1) * ind_nd_),

dpctl/tensor/libtensor/source/advanced_indexing.cpp renamed to dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@
2727
#include <algorithm>
2828
#include <complex>
2929
#include <cstdint>
30+
#include <iostream>
3031
#include <pybind11/complex.h>
3132
#include <pybind11/pybind11.h>
3233
#include <pybind11/stl.h>
3334
#include <utility>
3435

3536
#include "dpctl4pybind11.hpp"
36-
#include "kernels/advanced_indexing.hpp"
37+
#include "kernels/integer_advanced_indexing.hpp"
3738
#include "utils/type_dispatch.hpp"
3839
#include "utils/type_utils.hpp"
3940

41+
#include "integer_advanced_indexing.hpp"
42+
4043
#define INDEXING_MODES 2
4144
#define CLIP_MODE 0
4245
#define WRAP_MODE 1
@@ -85,8 +88,8 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
8588
int arr_nd)
8689
{
8790

88-
int orthog_sh_elems = ((inp_nd - k) > 1) ? (inp_nd - k) : 1;
89-
int along_sh_elems = (ind_nd > 1) ? ind_nd : 1;
91+
int orthog_sh_elems = std::max<int>(inp_nd - k, 1);
92+
int along_sh_elems = std::max<int>(ind_nd, 1);
9093

9194
using usm_host_allocatorT =
9295
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
@@ -284,7 +287,7 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
284287
int axis_start,
285288
uint8_t mode,
286289
sycl::queue exec_q,
287-
const std::vector<sycl::event> &depends = {})
290+
const std::vector<sycl::event> &depends)
288291
{
289292
int k = ind.size();
290293

@@ -328,7 +331,7 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
328331
const py::ssize_t *src_shape = src.get_shape_raw();
329332
const py::ssize_t *dst_shape = dst.get_shape_raw();
330333

331-
int orthog_nd = ((src_nd - k) > 0) ? src_nd - k : 1;
334+
int orthog_nd = std::max<int>(src_nd - k, 1);
332335

333336
bool orthog_shapes_equal(true);
334337
size_t orthog_nelems(1);
@@ -412,7 +415,7 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
412415
}
413416
}
414417

415-
auto ind_sh_elems = (ind_nd > 0) ? ind_nd : 1;
418+
int ind_sh_elems = std::max<int>(ind_nd, 1);
416419

417420
std::vector<char *> ind_ptrs;
418421
ind_ptrs.reserve(k);
@@ -633,12 +636,15 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
633636
std::to_string(ind_type_id));
634637
}
635638

639+
std::cout << "Submitting take" << std::endl;
636640
sycl::event take_generic_ev =
637641
fn(exec_q, orthog_nelems, ind_nelems, orthog_nd, ind_nd, k,
638642
packed_shapes_strides, packed_axes_shapes_strides,
639643
packed_ind_shapes_strides, src_data, dst_data, packed_ind_ptrs,
640644
src_offset, dst_offset, packed_ind_offsets, all_deps);
641645

646+
std::cout << "Submitting take clean-up host task" << std::endl;
647+
642648
// free packed temporaries
643649
auto ctx = exec_q.get_context();
644650
exec_q.submit([&](sycl::handler &cgh) {
@@ -666,7 +672,7 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
666672
int axis_start,
667673
uint8_t mode,
668674
sycl::queue exec_q,
669-
const std::vector<sycl::event> &depends = {})
675+
const std::vector<sycl::event> &depends)
670676
{
671677
int k = ind.size();
672678

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@
3333

3434
#include "dpctl4pybind11.hpp"
3535

36-
#include "advanced_indexing.hpp"
3736
#include "copy_and_cast_usm_to_usm.hpp"
3837
#include "copy_for_reshape.hpp"
3938
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
4039
#include "device_support_queries.hpp"
4140
#include "eye_ctor.hpp"
4241
#include "full_ctor.hpp"
42+
#include "integer_advanced_indexing.hpp"
4343
#include "linear_sequences.hpp"
4444
#include "triul_ctor.hpp"
4545
#include "utils/strided_iters.hpp"

0 commit comments

Comments
 (0)