@@ -23,6 +23,164 @@ namespace phi {
2323
2424namespace funcs {
2525
26+ /* *
27+ * @brief Normalizes the slice interval [st, ed) with a given step and dimension
28+ * size.
29+ *
30+ * This function adjusts the interval [st, ed) to fit within the bounds defined
31+ * by the dimension size, taking into account the specified step. It handles
32+ * both positive and negative steps and accounts for negative indices by
33+ * converting them to equivalent positive indices within the dimension size.
34+ *
35+ * @tparam T The data type of the input parameters, which can be an integer or
36+ * floating-point type.
37+ * @param st The starting index of the interval.
38+ * @param ed The ending index of the interval (exclusive).
39+ * @param step The step size for iterating through the interval, which can be
40+ * positive or negative.
41+ * @param dim_size The size of the dimension, serving as the upper bound for
42+ * valid indices.
43+ * @param st_out Pointer to store the normalized starting index.
44+ * @param ed_out Pointer to store the normalized ending index.
45+ * @param zero_dim_out Pointer to a boolean flag that is set to true if the
46+ * resulting interval is empty.
47+ *
48+ * @details
49+ * - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be
50+ * within the range [0, dim_size).
51+ * - If `step < 0`, the function adjusts `st` and `ed` to accommodate the
52+ * reverse traversal of the interval.
53+ * - Handles special cases where `st` and `ed` may be out of bounds or where
54+ * `dim_size` is zero.
55+ * - Uses pointer parameters for output to modify the values directly.
56+ * - The function also handles scenarios involving negative indices, converting
57+ * them appropriately.
58+ *
59+ * @example
60+ * T st_out, ed_out;
61+ * bool zero_dim;
62+ * normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
63+ * // Results in: st_out = 1, ed_out = 2, zero_dim = false
64+ *
65+ * @note The function assumes that the pointers provided for output parameters
66+ * are valid and non-null.
67+ */
68+ template <typename T>
69+ void normalize_interval (
70+ T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool * zero_dim_out) {
71+ /* Normalize slice interval [st, ed) with given step and dim_size.
72+ e.g. if given st = -3, ed = -2, step = 1, dim_size = 4,
73+ then normalized st_out = 1(-3+4), st_ed = 2(-2+4).
74+
75+ This function is general enough and applicable
76+ for both step > 0 and step < 0 scenarios.
77+
78+ Indicices dipicted as below:
79+
80+ ===============================================================
81+ | 0 1 2 3 ... D-1 | D D+1 ...
82+ ... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 |
83+ ===============================================================
84+ */
85+ // 0 dim size, just return
86+ if (dim_size <= 0 ) {
87+ *st_out = *ed_out = 0 ;
88+ *zero_dim_out = true ;
89+ return ;
90+ }
91+
92+ if (step > 0 ) {
93+ /* positive step */
94+ // 0 dim size case 1
95+ if (st >= dim_size) {
96+ *st_out = *ed_out = 0 ;
97+ *zero_dim_out = true ;
98+ return ;
99+ }
100+
101+ // 0 dim size case 2
102+ if (ed <= -dim_size) {
103+ *st_out = *ed_out = 0 ;
104+ *zero_dim_out = true ;
105+ return ;
106+ }
107+
108+ // make st belongs: (-inf, -D-1)∪[0, D)
109+ if (-dim_size <= st && st < 0 ) {
110+ st += dim_size;
111+ }
112+ // make st belongs: [0, D)
113+ st = std::max (st, static_cast <T>(0 ));
114+
115+ // make ed belongs: [0, +inf)
116+ if (-dim_size <= ed && ed < 0 ) {
117+ ed += dim_size;
118+ }
119+ // make ed belongs: [0, D]
120+ ed = std::min (ed, dim_size);
121+
122+ // 0 dim size case 3
123+ if (st >= ed) {
124+ *st_out = *ed_out = 0 ;
125+ *zero_dim_out = true ;
126+ return ;
127+ }
128+ *st_out = st;
129+ *ed_out = ed;
130+ return ;
131+
132+ } else {
133+ /* negative step */
134+ // 0 dim size case 1
135+ if (st <= -dim_size - 1 ) {
136+ *st_out = *ed_out = 0 ;
137+ *zero_dim_out = true ;
138+ return ;
139+ }
140+
141+ // 0 dim size case 2
142+ if (ed >= dim_size - 1 ) {
143+ *st_out = *ed_out = 0 ;
144+ *zero_dim_out = true ;
145+ return ;
146+ }
147+
148+ // make st belongs: [0, D)∪[0, +inf)
149+ if (-dim_size <= st && st < 0 ) {
150+ st += dim_size;
151+ }
152+ // make st belongs: [0, D)
153+ st = std::min (st, dim_size - 1 );
154+
155+ // make ed belongs: [-inf, -D)∪[0, D)
156+ if (-dim_size <= ed && ed < 0 ) {
157+ ed += dim_size;
158+ }
159+ // make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D)
160+ ed = std::max (ed, -dim_size - 1 );
161+
162+ if (ed == -dim_size - 1 ) {
163+ // When ed=-D-1, it is symmetrical to when step is greater than 0 and
164+ // ed=D.
165+ *st_out = st;
166+ *ed_out = ed;
167+ return ;
168+ }
169+
170+ // now only remain the case that ed belongs to: [0, D)
171+ // 0 dim size case 3
172+ if (ed >= st) {
173+ *st_out = *ed_out = 0 ;
174+ *zero_dim_out = true ;
175+ return ;
176+ }
177+
178+ *st_out = st;
179+ *ed_out = ed;
180+ return ;
181+ }
182+ }
183+
26184template <typename T = int64_t >
27185inline void CheckAndUpdateSliceAttrs (const DDim in_dims,
28186 const std::vector<T>& axes,
@@ -56,41 +214,17 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
56214 common::errors::InvalidArgument (
57215 " Step should not be 0, but received step = %d." , step));
58216
59- T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
60- start = std::max (start, static_cast <T>(0 ));
61-
62- T end =
63- 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
64- end = std::min (end, dim_value);
65-
66- if (step > 0 ) {
67- start = std::min (start, dim_value);
68- end = std::max (end, static_cast <T>(0 ));
69- PADDLE_ENFORCE_GE (
70- end,
71- start,
72- common::errors::InvalidArgument (
73- " When step > 0, end should be greater than start, but "
74- " received end = %d, start = %d." ,
75- end,
76- start));
77- } else {
78- // NOTE(liym27): When step < 0, start should less and equal to
79- // dim_value-1
80- // "end is -1" means contain the 0-th element of this axis.
81- start = std::min (start, dim_value - 1 );
82- if (end < -1 ) {
83- end += dim_value;
84- }
85- end = std::max (end, static_cast <T>(-1 ));
86- PADDLE_ENFORCE_GE (
87- start,
88- end,
89- common::errors::InvalidArgument (
90- " When step < 0, start should be greater than end, but "
91- " received start = %d, end = %d." ,
92- start,
93- end));
217+ T start, end;
218+ bool dummy_zero_out_dim = false ;
219+ normalize_interval ((*starts)[i],
220+ (*ends)[i],
221+ step,
222+ dim_value,
223+ &start,
224+ &end,
225+ &dummy_zero_out_dim);
226+ if (end == -dim_value - 1 ) {
227+ end = -1 ;
94228 }
95229
96230 (*starts)[i] = start;
@@ -117,24 +251,17 @@ inline void UpdateSliceAttrs(const DDim in_dims,
117251 T dim_value = in_dims[axis];
118252 if (dim_value > 0 ) {
119253 T step = steps == nullptr ? 1 : (*steps)[i];
120- T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
121- start = std::max (start, static_cast <T>(0 ));
122- T end =
123- 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
124- end = std::min (end, dim_value);
125-
126- if (step > 0 ) {
127- start = std::min (start, dim_value);
128- end = std::max (end, static_cast <T>(0 ));
129- } else {
130- // NOTE: When step < 0, start should less and equal to
131- // dim_value-1
132- // "end is -1" means contain the 0-th element of this axis.
133- start = std::min (start, dim_value - 1 );
134- if (end < -1 ) {
135- end += dim_value;
136- }
137- end = std::max (end, static_cast <T>(-1 ));
254+ T start = (*starts)[i];
255+ T end = (*ends)[i];
256+
257+ bool dummy_zero_out_dim = false ;
258+ normalize_interval (
259+ start, end, step, dim_value, &start, &end, &dummy_zero_out_dim);
260+
261+ // manually set the end to -1 when step < 0,
262+ // which indicates that it can extend to the left endpoint.
263+ if (end == -dim_value - 1 && step < 0 ) {
264+ end = -1 ;
138265 }
139266 (*starts)[i] = start;
140267 (*ends)[i] = end;
0 commit comments