Skip to content

Commit e6180bc

Browse files
fix slice endpoints computation
1 parent 6305ed9 commit e6180bc

File tree

1 file changed

+29
-18
lines changed

1 file changed

+29
-18
lines changed

backends/mlu/kernels/slice_kernel.cc

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -230,28 +230,39 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,
230230
continue;
231231
}
232232

233-
T start, end;
234-
bool dummy_zero_out_dim = false;
235-
normalize_interval((*starts)[i],
236-
(*ends)[i],
237-
step,
238-
dim_value,
239-
&start,
240-
&end,
241-
&dummy_zero_out_dim);
242-
if (end == -dim_value - 1) {
243-
end = -1;
244-
}
233+
T dim_value = in_dims[axis];
234+
235+
if (dim_value > 0) {
236+
T step = steps == nullptr ? 1 : (*steps)[i];
237+
PADDLE_ENFORCE_NE(
238+
step,
239+
0,
240+
common::errors::InvalidArgument(
241+
"Step should not be 0, but received step = %d.", step));
242+
243+
T start, end;
244+
bool dummy_zero_out_dim = false;
245+
normalize_interval((*starts)[i],
246+
(*ends)[i],
247+
step,
248+
dim_value,
249+
&start,
250+
&end,
251+
&dummy_zero_out_dim);
252+
if (end == -dim_value - 1) {
253+
end = -1;
254+
}
245255

246-
(*starts)[i] = start;
247-
(*ends)[i] = end;
248-
} else if (dim_value == 0) {
249-
(*starts)[i] = 0;
250-
(*ends)[i] = 0;
256+
(*starts)[i] = start;
257+
(*ends)[i] = end;
258+
} else if (dim_value == 0) {
259+
(*starts)[i] = 0;
260+
(*ends)[i] = 0;
261+
}
251262
}
252263
}
253264

254-
} // custom_kernel
265+
} // namespace custom_kernel
255266

256267
template <typename T = int64_t>
257268
inline phi::DDim GetSliceDims(const phi::DDim in_dims,

0 commit comments

Comments
 (0)