@@ -221,6 +221,71 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
221
221
return result;
222
222
}
223
223
224
+ static bool PyCheckInteger (PyObject *obj) {
225
+ #if PY_VERSION_HEX < 0x03000000
226
+ return (PyLong_Check (obj) || PyInt_Check (obj)) && !PyBool_Check (obj);
227
+ #else
228
+ return PyLong_Check (obj) && !PyBool_Check (obj);
229
+ #endif
230
+ }
231
+
232
+ // NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
233
+ // https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
234
+ // Original PySlice_GetIndices return wrong result when
235
+ // slice_item contains long int, such as arr[:180L].
236
+ // NOT sure why this happens !!!
237
+ // Besides, PySlice_GetIndices cannot raise error when float in slice item.
238
+ // So, I make a revised version of PySlice_GetIndices, named to
239
+ // _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
240
+ // PySlice_GetIndices in the future.
241
+ static int _PySlice_GetIndices (PySliceObject *r, Py_ssize_t length,
242
+ Py_ssize_t *start, Py_ssize_t *stop,
243
+ Py_ssize_t *step) {
244
+ /* XXX support long ints */
245
+ if (r->step == Py_None) {
246
+ *step = 1 ;
247
+ } else {
248
+ if (PyCheckInteger (r->step )) {
249
+ *step = PyLong_AsLong (r->step );
250
+ } else {
251
+ PADDLE_THROW (platform::errors::InvalidArgument (
252
+ " Currently, VarBase.__getitem__() only allows None or integers in "
253
+ " slice item, but received %s." ,
254
+ std::string (Py_TYPE (r->step )->tp_name )));
255
+ }
256
+ }
257
+ if (r->start == Py_None) {
258
+ *start = *step < 0 ? length - 1 : 0 ;
259
+ } else {
260
+ if (PyCheckInteger (r->start )) {
261
+ *start = PyLong_AsLong (r->start );
262
+ } else {
263
+ PADDLE_THROW (platform::errors::InvalidArgument (
264
+ " Currently, VarBase.__getitem__() only allows None or integers in "
265
+ " slice item, but received %s." ,
266
+ std::string (Py_TYPE (r->start )->tp_name )));
267
+ }
268
+ if (*start < 0 ) *start += length;
269
+ }
270
+ if (r->stop == Py_None) {
271
+ *stop = *step < 0 ? -1 : length;
272
+ } else {
273
+ if (PyCheckInteger (r->stop )) {
274
+ *stop = PyLong_AsLong (r->stop );
275
+ } else {
276
+ PADDLE_THROW (platform::errors::InvalidArgument (
277
+ " Currently, VarBase.__getitem__() only allows None or integers in "
278
+ " slice item, but received %s." ,
279
+ std::string (Py_TYPE (r->stop )->tp_name )));
280
+ }
281
+ if (*stop < 0 ) *stop += length;
282
+ }
283
+ if (*stop > length) return -1 ;
284
+ if (*start >= length) return -1 ;
285
+ if (*step == 0 ) return -1 ;
286
+ return 0 ;
287
+ }
288
+
224
289
static void ParseIndexingSlice (framework::LoDTensor *tensor, PyObject *_index,
225
290
std::vector<int > *slice_axes,
226
291
std::vector<int > *slice_starts,
@@ -245,16 +310,17 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
245
310
" too many indices (%d) for tensor of dimension %d" , size, rank));
246
311
for (int dim = 0 ; dim < size; ++dim) {
247
312
PyObject *slice_item = PyTuple_GetItem (index, dim);
248
- PADDLE_ENFORCE_EQ (
249
- PyNumber_Check (slice_item) || PySlice_Check (slice_item), true ,
250
- platform::errors::InvalidArgument (
251
- " We allow indexing by Integers, Slices, and tuples of "
252
- " these types, but received %s in %dth slice item" ,
253
- std::string (Py_TYPE (slice_item)->tp_name ), dim + 1 ));
313
+ PADDLE_ENFORCE_EQ (PyCheckInteger (slice_item) || PySlice_Check (slice_item),
314
+ true ,
315
+ platform::errors::InvalidArgument (
316
+ " Currently, VarBase.__getitem__() only allows "
317
+ " indexing by Integers, Slices, and tuples of "
318
+ " these types, but received %s in %dth slice item" ,
319
+ std::string (Py_TYPE (slice_item)->tp_name ), dim + 1 ));
254
320
infer_flags->push_back (1 );
255
321
int dim_len = shape[dim];
256
- if (PyNumber_Check (slice_item)) {
257
- // integer
322
+ if (PyCheckInteger (slice_item)) {
323
+ // integer, PyLong_AsLong supports both int and long
258
324
int start = static_cast <int >(PyLong_AsLong (slice_item));
259
325
start = start < 0 ? start + dim_len : start;
260
326
slice_axes->push_back (dim);
@@ -263,17 +329,15 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
263
329
slice_strides->push_back (1 );
264
330
decrease_axis->push_back (dim);
265
331
} else {
266
- // slice
332
+ // slice item
267
333
Py_ssize_t start, end, step;
268
- // The parameter type for the slice parameter was PySliceObject* before 3.2
269
- #if PY_VERSION_HEX >= 0x03020000
270
- PySlice_GetIndices (slice_item, dim_len, &start, &end, &step);
271
- #else
272
- PySlice_GetIndices (reinterpret_cast <PySliceObject *>(slice_item), dim_len,
273
- &start, &end, &step);
274
- #endif
334
+ PySliceObject *p = reinterpret_cast <PySliceObject *>(slice_item);
335
+ _PySlice_GetIndices (p, dim_len, &start, &end, &step);
336
+
275
337
// :: or : or 0:dim_len:1
276
- if (start == 0 && end == dim_len && step == 1 ) continue ;
338
+ if (start == 0 && end == dim_len && step == 1 ) {
339
+ continue ;
340
+ }
277
341
slice_axes->push_back (dim);
278
342
slice_starts->push_back (start);
279
343
slice_ends->push_back (end);
@@ -481,7 +545,6 @@ void BindImperative(py::module *m_ptr) {
481
545
ParseIndexingSlice (tensor, _index.ptr (), &slice_axes,
482
546
&slice_starts, &slice_ends, &slice_strides,
483
547
&decrease_axis, &infer_flags);
484
-
485
548
// release gil and do tracing
486
549
py::gil_scoped_release release;
487
550
const auto &tracer = imperative::GetCurrentTracer ();
@@ -621,8 +684,8 @@ void BindImperative(py::module *m_ptr) {
621
684
[](imperative::VarBase &self,
622
685
const imperative::detail::BackwardStrategy &bckst,
623
686
const imperative::Tracer &tracer) {
624
- // TODO(jiabin): when we impl more backward execution we can select
625
- // them
687
+ // TODO(jiabin): when we impl more backward execution we can
688
+ // select them
626
689
auto *engine = tracer.GetEngine ();
627
690
engine->Init (&self, bckst);
628
691
VLOG (3 ) << " Start backward" ;
0 commit comments