1
- // === boolean_advance_indexing.hpp - ---*-C++-*--/===//
1
+ // === boolean_advance_indexing.hpp - --- ---*-C++-*--/===//
2
2
//
3
3
// Data Parallel Control (dpctl)
4
4
//
16
16
// See the License for the specific language governing permissions and
17
17
// limitations under the License.
18
18
//
19
- // ===---------------------------------------------------------------------- ===//
19
+ // ===---------------------------------------------------------------------===//
20
20
// /
21
21
// / \file
22
22
// / This file defines kernels for advanced tensor index operations.
23
- // ===---------------------------------------------------------------------- ===//
23
+ // ===---------------------------------------------------------------------===//
24
24
25
25
#pragma once
26
26
#include < CL/sycl.hpp>
@@ -114,6 +114,26 @@ struct Strided1DIndexer
114
114
py::ssize_t step = 1 ;
115
115
};
116
116
117
+ struct Strided1DCyclicIndexer
118
+ {
119
+ Strided1DCyclicIndexer (py::ssize_t _offset,
120
+ py::ssize_t _size,
121
+ py::ssize_t _step)
122
+ : offset(_offset), size(static_cast <size_t >(_size)), step(_step)
123
+ {
124
+ }
125
+
126
+ size_t operator ()(size_t gid) const
127
+ {
128
+ return static_cast <size_t >(offset + (gid % size) * step);
129
+ }
130
+
131
+ private:
132
+ py::ssize_t offset = 0 ;
133
+ size_t size = 1 ;
134
+ py::ssize_t step = 1 ;
135
+ };
136
+
117
137
template <typename _IndexerFn> struct ZeroChecker
118
138
{
119
139
@@ -762,27 +782,22 @@ sycl::event masked_place_all_slices_strided_impl(
762
782
py::ssize_t rhs_stride,
763
783
const std::vector<sycl::event> &depends = {})
764
784
{
765
- // using MaskedPlaceStridedFunctor;
766
- // using Strided1DIndexer;
767
- // using StridedIndexer;
768
- // using TwoZeroOffsets_Indexer;
769
-
770
785
TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
771
786
772
787
/* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const
773
788
* *_packed_shape_strides) */
774
789
StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
775
- Strided1DIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
790
+ Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
776
791
777
792
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
778
793
cgh.depends_on (depends);
779
794
780
795
cgh.parallel_for <class masked_place_all_slices_strided_impl_krn <
781
- TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT ,
782
- indT>>(
796
+ TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer ,
797
+ dataT, indT>>(
783
798
sycl::range<1 >(static_cast <size_t >(iteration_size)),
784
799
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
785
- Strided1DIndexer , dataT, indT>(
800
+ Strided1DCyclicIndexer , dataT, indT>(
786
801
dst_p, cumsum_p, rhs_p, 1 , iteration_size,
787
802
orthog_dst_rhs_indexer, masked_dst_indexer,
788
803
masked_rhs_indexer));
@@ -838,11 +853,6 @@ sycl::event masked_place_some_slices_strided_impl(
838
853
py::ssize_t masked_rhs_stride,
839
854
const std::vector<sycl::event> &depends = {})
840
855
{
841
- // using MaskedPlaceStridedFunctor;
842
- // using Strided1DIndexer;
843
- // using StridedIndexer;
844
- // using TwoOffsets_StridedIndexer;
845
-
846
856
TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
847
857
orthog_nd, ortho_dst_offset, ortho_rhs_offset,
848
858
packed_ortho_dst_rhs_shape_strides};
@@ -851,17 +861,18 @@ sycl::event masked_place_some_slices_strided_impl(
851
861
* *_packed_shape_strides) */
852
862
StridedIndexer masked_dst_indexer{masked_nd, 0 ,
853
863
packed_masked_dst_shape_strides};
854
- Strided1DIndexer masked_rhs_indexer{0 , masked_rhs_size, masked_rhs_stride};
864
+ Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
865
+ masked_rhs_stride};
855
866
856
867
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
857
868
cgh.depends_on (depends);
858
869
859
870
cgh.parallel_for <class masked_place_some_slices_strided_impl_krn <
860
- TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT ,
861
- indT>>(
871
+ TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer ,
872
+ dataT, indT>>(
862
873
sycl::range<1 >(static_cast <size_t >(orthog_nelems * masked_nelems)),
863
874
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
864
- Strided1DIndexer , dataT, indT>(
875
+ Strided1DCyclicIndexer , dataT, indT>(
865
876
dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
866
877
orthog_dst_rhs_indexer, masked_dst_indexer,
867
878
masked_rhs_indexer));
0 commit comments