Skip to content

Commit 9fd1dd0

Browse files
authored
Fix get item out of range error (#24339) (#24943)
* raise index error when slice out of range; test=develop * add uni test; test=develop * fix format error; test=develop * add comment for py::index_error; test=develop * polish error message; test=develop * polish error message; test=develop
1 parent 5fc4275 commit 9fd1dd0

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,18 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
322322
if (PyCheckInteger(slice_item)) {
323323
// integer, PyLong_AsLong supports both int and long
324324
int start = static_cast<int>(PyLong_AsLong(slice_item));
325+
auto s_t = start;
325326
start = start < 0 ? start + dim_len : start;
327+
if (start >= dim_len) {
328+
std::string str_error_message =
329+
"The starting index " + std::to_string(s_t) +
330+
" of slice is out of bounds in tensor " + std::to_string(dim) +
331+
"-th axis, it shound be in the range of [" +
332+
std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
333+
// py::index_error is corresponding to IndexError in Python
334+
// Used to indicate out of bounds access in __getitem__, __setitem__
335+
throw py::index_error(str_error_message);
336+
}
326337
slice_axes->push_back(dim);
327338
slice_starts->push_back(start);
328339
slice_ends->push_back(start + 1);

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,25 @@ def _test_slice(self):
181181
self.assertTrue(
182182
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
183183

184+
def _test_for_var(self):
185+
np_value = np.random.random((30, 100, 100)).astype('float32')
186+
w = fluid.dygraph.to_variable(np_value)
187+
188+
for i, e in enumerate(w):
189+
self.assertTrue(np.array_equal(e.numpy(), np_value[i]))
190+
184191
def test_slice(self):
185192
with fluid.dygraph.guard():
186193
self._test_slice()
194+
self._test_for_var()
187195

188196
var = fluid.dygraph.to_variable(self.array)
189197
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
190198
self.assertTrue(np.array_equal(var[::-1].numpy(), self.array[::-1]))
191199

200+
with self.assertRaises(IndexError):
201+
y = var[self.shape[0]]
202+
192203
def test_var_base_to_np(self):
193204
with fluid.dygraph.guard():
194205
var = fluid.dygraph.to_variable(self.array)

0 commit comments

Comments
 (0)