Skip to content

Commit 3ced89a

Browse files
Used Strided1DCyclingIndexer in place implementations
This allows to implement behavior of place which cycles over values of val array if that is shorter than the number of non-zero elements in the mask.
1 parent ed279d6 commit 3ced89a

File tree

1 file changed

+32
-21
lines changed

1 file changed

+32
-21
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//=== boolean_advance_indexing.hpp - ---*-C++-*--/===//
1+
//=== boolean_advance_indexing.hpp - ------*-C++-*--/===//
22
//
33
// Data Parallel Control (dpctl)
44
//
@@ -16,11 +16,11 @@
1616
// See the License for the specific language governing permissions and
1717
// limitations under the License.
1818
//
19-
//===----------------------------------------------------------------------===//
19+
//===---------------------------------------------------------------------===//
2020
///
2121
/// \file
2222
/// This file defines kernels for advanced tensor index operations.
23-
//===----------------------------------------------------------------------===//
23+
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
2626
#include <CL/sycl.hpp>
@@ -114,6 +114,26 @@ struct Strided1DIndexer
114114
py::ssize_t step = 1;
115115
};
116116

117+
struct Strided1DCyclicIndexer
118+
{
119+
Strided1DCyclicIndexer(py::ssize_t _offset,
120+
py::ssize_t _size,
121+
py::ssize_t _step)
122+
: offset(_offset), size(static_cast<size_t>(_size)), step(_step)
123+
{
124+
}
125+
126+
size_t operator()(size_t gid) const
127+
{
128+
return static_cast<size_t>(offset + (gid % size) * step);
129+
}
130+
131+
private:
132+
py::ssize_t offset = 0;
133+
size_t size = 1;
134+
py::ssize_t step = 1;
135+
};
136+
117137
template <typename _IndexerFn> struct ZeroChecker
118138
{
119139

@@ -762,27 +782,22 @@ sycl::event masked_place_all_slices_strided_impl(
762782
py::ssize_t rhs_stride,
763783
const std::vector<sycl::event> &depends = {})
764784
{
765-
// using MaskedPlaceStridedFunctor;
766-
// using Strided1DIndexer;
767-
// using StridedIndexer;
768-
// using TwoZeroOffsets_Indexer;
769-
770785
TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
771786

772787
/* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const
773788
* *_packed_shape_strides) */
774789
StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides);
775-
Strided1DIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);
790+
Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);
776791

777792
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
778793
cgh.depends_on(depends);
779794

780795
cgh.parallel_for<class masked_place_all_slices_strided_impl_krn<
781-
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT,
782-
indT>>(
796+
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer,
797+
dataT, indT>>(
783798
sycl::range<1>(static_cast<size_t>(iteration_size)),
784799
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
785-
Strided1DIndexer, dataT, indT>(
800+
Strided1DCyclicIndexer, dataT, indT>(
786801
dst_p, cumsum_p, rhs_p, 1, iteration_size,
787802
orthog_dst_rhs_indexer, masked_dst_indexer,
788803
masked_rhs_indexer));
@@ -838,11 +853,6 @@ sycl::event masked_place_some_slices_strided_impl(
838853
py::ssize_t masked_rhs_stride,
839854
const std::vector<sycl::event> &depends = {})
840855
{
841-
// using MaskedPlaceStridedFunctor;
842-
// using Strided1DIndexer;
843-
// using StridedIndexer;
844-
// using TwoOffsets_StridedIndexer;
845-
846856
TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
847857
orthog_nd, ortho_dst_offset, ortho_rhs_offset,
848858
packed_ortho_dst_rhs_shape_strides};
@@ -851,17 +861,18 @@ sycl::event masked_place_some_slices_strided_impl(
851861
* *_packed_shape_strides) */
852862
StridedIndexer masked_dst_indexer{masked_nd, 0,
853863
packed_masked_dst_shape_strides};
854-
Strided1DIndexer masked_rhs_indexer{0, masked_rhs_size, masked_rhs_stride};
864+
Strided1DCyclicIndexer masked_rhs_indexer{0, masked_rhs_size,
865+
masked_rhs_stride};
855866

856867
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
857868
cgh.depends_on(depends);
858869

859870
cgh.parallel_for<class masked_place_some_slices_strided_impl_krn<
860-
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
861-
indT>>(
871+
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
872+
dataT, indT>>(
862873
sycl::range<1>(static_cast<size_t>(orthog_nelems * masked_nelems)),
863874
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
864-
Strided1DIndexer, dataT, indT>(
875+
Strided1DCyclicIndexer, dataT, indT>(
865876
dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
866877
orthog_dst_rhs_indexer, masked_dst_indexer,
867878
masked_rhs_indexer));

0 commit comments

Comments
 (0)