Skip to content

Commit 477a1ee

Browse files
authored
[cherry-pick] fix bug of varbase.__getitem__, test=develop (#24647)
* fix bug of varbase.__getitem__, test=develop (#24642) * fix bug of varbase.__getitem__, test=develop * fix bug of float and other type, test=develop * fix unittest, test=develop
1 parent d84a30b commit 477a1ee

File tree

2 files changed

+133
-20
lines changed

2 files changed

+133
-20
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
221221
return result;
222222
}
223223

224+
// Tells whether `obj` is a Python integer that may be used as an index or as
// a slice component. Python bools are rejected explicitly even though they
// pass the integer type checks below. On Python 2 both `int` and `long`
// objects are accepted; on Python 3 only `int` (PyLong) objects are.
static bool PyCheckInteger(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyLong_Check(obj) || PyInt_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}
231+
232+
// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
233+
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
234+
// Original PySlice_GetIndices return wrong result when
235+
// slice_item contains long int, such as arr[:180L].
236+
// NOT sure why this happens !!!
237+
// Besides, PySlice_GetIndices cannot raise error when float in slice item.
238+
// So, I make a revised version of PySlice_GetIndices, named to
239+
// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
240+
// PySlice_GetIndices in the future.
241+
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
242+
Py_ssize_t *start, Py_ssize_t *stop,
243+
Py_ssize_t *step) {
244+
/* XXX support long ints */
245+
if (r->step == Py_None) {
246+
*step = 1;
247+
} else {
248+
if (PyCheckInteger(r->step)) {
249+
*step = PyLong_AsLong(r->step);
250+
} else {
251+
PADDLE_THROW(platform::errors::InvalidArgument(
252+
"Currently, VarBase.__getitem__() only allows None or integers in "
253+
"slice item, but received %s.",
254+
std::string(Py_TYPE(r->step)->tp_name)));
255+
}
256+
}
257+
if (r->start == Py_None) {
258+
*start = *step < 0 ? length - 1 : 0;
259+
} else {
260+
if (PyCheckInteger(r->start)) {
261+
*start = PyLong_AsLong(r->start);
262+
} else {
263+
PADDLE_THROW(platform::errors::InvalidArgument(
264+
"Currently, VarBase.__getitem__() only allows None or integers in "
265+
"slice item, but received %s.",
266+
std::string(Py_TYPE(r->start)->tp_name)));
267+
}
268+
if (*start < 0) *start += length;
269+
}
270+
if (r->stop == Py_None) {
271+
*stop = *step < 0 ? -1 : length;
272+
} else {
273+
if (PyCheckInteger(r->stop)) {
274+
*stop = PyLong_AsLong(r->stop);
275+
} else {
276+
PADDLE_THROW(platform::errors::InvalidArgument(
277+
"Currently, VarBase.__getitem__() only allows None or integers in "
278+
"slice item, but received %s.",
279+
std::string(Py_TYPE(r->stop)->tp_name)));
280+
}
281+
if (*stop < 0) *stop += length;
282+
}
283+
if (*stop > length) return -1;
284+
if (*start >= length) return -1;
285+
if (*step == 0) return -1;
286+
return 0;
287+
}
288+
224289
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
225290
std::vector<int> *slice_axes,
226291
std::vector<int> *slice_starts,
@@ -245,16 +310,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
245310
"too many indices (%d) for tensor of dimension %d", size, rank));
246311
for (int dim = 0; dim < size; ++dim) {
247312
PyObject *slice_item = PyTuple_GetItem(index, dim);
248-
PADDLE_ENFORCE_EQ(
249-
PyNumber_Check(slice_item) || PySlice_Check(slice_item), true,
250-
platform::errors::InvalidArgument(
251-
"We allow indexing by Integers, Slices, and tuples of "
252-
"these types, but received %s in %dth slice item",
253-
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
313+
PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
314+
true,
315+
platform::errors::InvalidArgument(
316+
"Currently, VarBase.__getitem__() only allows "
317+
"indexing by Integers, Slices, and tuples of "
318+
"these types, but received %s in %dth slice item",
319+
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
254320
infer_flags->push_back(1);
255321
int dim_len = shape[dim];
256-
if (PyNumber_Check(slice_item)) {
257-
// integer
322+
if (PyCheckInteger(slice_item)) {
323+
// integer, PyLong_AsLong supports both int and long
258324
int start = static_cast<int>(PyLong_AsLong(slice_item));
259325
start = start < 0 ? start + dim_len : start;
260326
slice_axes->push_back(dim);
@@ -263,17 +329,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
263329
slice_strides->push_back(1);
264330
decrease_axis->push_back(dim);
265331
} else {
266-
// slice
332+
// slice item
267333
Py_ssize_t start, end, step;
268-
// The parameter type for the slice parameter was PySliceObject* before 3.2
269-
#if PY_VERSION_HEX >= 0x03020000
270-
PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
271-
#else
272-
PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item), dim_len,
273-
&start, &end, &step);
274-
#endif
334+
PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
335+
_PySlice_GetIndices(p, dim_len, &start, &end, &step);
336+
275337
// :: or : or 0:dim_len:1
276-
if (start == 0 && end == dim_len && step == 1) continue;
338+
if (start == 0 && end == dim_len && step == 1) {
339+
continue;
340+
}
277341
slice_axes->push_back(dim);
278342
slice_starts->push_back(start);
279343
slice_ends->push_back(end);
@@ -481,7 +545,6 @@ void BindImperative(py::module *m_ptr) {
481545
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
482546
&slice_starts, &slice_ends, &slice_strides,
483547
&decrease_axis, &infer_flags);
484-
485548
// release gil and do tracing
486549
py::gil_scoped_release release;
487550
const auto &tracer = imperative::GetCurrentTracer();
@@ -621,8 +684,8 @@ void BindImperative(py::module *m_ptr) {
621684
[](imperative::VarBase &self,
622685
const imperative::detail::BackwardStrategy &bckst,
623686
const imperative::Tracer &tracer) {
624-
// TODO(jiabin): when we impl more backward execution we can select
625-
// them
687+
// TODO(jiabin): when we impl more backward execution we can
688+
// select them
626689
auto *engine = tracer.GetEngine();
627690
engine->Init(&self, bckst);
628691
VLOG(3) << "Start backward";
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
import paddle.fluid as fluid
18+
19+
20+
class TestImperativeVarBaseGetItem(unittest.TestCase):
    """Checks for indexing a dygraph variable with ``[]``."""

    def test_getitem_with_long(self):
        """Slice bounds read back from ``shape`` give the expected shapes."""
        with fluid.dygraph.guard():
            array = np.random.random((2, 80, 16128)).astype('float32')
            tensor = fluid.dygraph.to_variable(array)

            # Original author's note: tensor.shape[1] is 80L (a Python 2
            # long) here.
            result = tensor[:, 10:, :tensor.shape[1]]
            self.assertEqual(result.shape, [2, 70, 80])

            result = tensor[:, tensor.shape[0]:, tensor.shape[0]:tensor.shape[
                1]]
            self.assertEqual(result.shape, [2, 78, 78])

    def test_getitem_with_float(self):
        """A float used as a slice bound or as an index must raise."""
        # Float as a slice component.
        with self.assertRaises(Exception):
            with fluid.dygraph.guard():
                array = np.random.random((2, 80, 16128)).astype('float32')
                tensor = fluid.dygraph.to_variable(array)
                _ = tensor[:, 1.1:, :tensor.shape[1]]

        # Float as a plain index.
        with self.assertRaises(Exception):
            with fluid.dygraph.guard():
                array = np.random.random((2, 80, 16128)).astype('float32')
                tensor = fluid.dygraph.to_variable(array)
                _ = tensor[1.1]
47+
48+
49+
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
50+
unittest.main()

0 commit comments

Comments
 (0)