@@ -127,13 +127,8 @@ struct DeferredGather
127127 void *outPtr = nullptr ;
128128 py::handle res;
129129 if (!sendonly || !trscvr) {
130- auto tmp = a_ptr->shape ();
131- std::vector<ssize_t > tmpv (tmp, &tmp[a_ptr->ndims ()]);
132- // numpy treats 0d arrays as empty arrays, not as a scalar as we do
133- if (tmpv.empty ()) {
134- tmpv.emplace_back (1 );
135- }
136- res = dispatch<mk_array>(a_ptr->dtype (), std::move (tmpv), outPtr);
130+ std::vector<ssize_t > shp (a_ptr->shape ());
131+ res = dispatch<mk_array>(a_ptr->dtype (), std::move (shp), outPtr);
137132 }
138133
139134 gather_array (a_ptr, _root, outPtr);
@@ -309,9 +304,37 @@ struct DeferredGetItem : public Deferred {
309304
310305// ***************************************************************************
311306
307+ // extract "start", "stop", "step" int attrs from py::slice
308+ std::optional<int > getSliceAttr (const py::slice &slice, const char *name) {
309+ auto obj = getattr (slice, name);
310+ if (py::isinstance<py::none>(obj)) {
311+ return std::nullopt ;
312+ } else if (py::isinstance<py::int_>(obj)) {
313+ return std::optional<int >{obj.cast <int >()};
314+ } else {
315+ throw std::invalid_argument (" Invalid indices" );
316+ }
317+ };
318+
319+ // check that multi-dimensional slice start does not exceed given shape
320+ void validateSlice (const shape_type &shape,
321+ const std::vector<py::slice> &slices) {
322+ auto dim = shape.size ();
323+ for (std::size_t i = 0 ; i < dim; i++) {
324+ auto start = getSliceAttr (slices[i], " start" );
325+ if (start && start.value () >= shape[i]) {
326+ std::stringstream msg;
327+ msg << " index " << start.value () << " is out of bounds for axis " << i
328+ << " with size " << shape[i] << " \n " ;
329+ throw std::out_of_range (msg.str ());
330+ }
331+ }
332+ }
333+
312334FutureArray *GetItem::__getitem__ (const FutureArray &a,
313335 const std::vector<py::slice> &v) {
314336 auto afut = a.get ();
337+ validateSlice (afut.shape (), v);
315338 NDSlice slc (v, afut.shape ());
316339 return new FutureArray (defer<DeferredGetItem>(afut, std::move (slc)));
317340}
@@ -328,9 +351,10 @@ GetItem::py_future_type GetItem::gather(const FutureArray &a, rank_type root) {
328351FutureArray *SetItem::__setitem__ (FutureArray &a,
329352 const std::vector<py::slice> &v,
330353 const py::object &b) {
331- auto bb =
332- Creator::mk_future (b, a.get ().device (), a.get ().team (), a.get ().dtype ());
333- a.put (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
354+ auto afut = a.get ();
355+ validateSlice (afut.shape (), v);
356+ auto bb = Creator::mk_future (b, afut.device (), afut.team (), afut.dtype ());
357+ a.put (defer<DeferredSetItem>(afut, bb.first ->get (), v));
334358 if (bb.second )
335359 delete bb.first ;
336360 return &a;
0 commit comments