@@ -176,43 +176,60 @@ Tensor Tensor::view(svector<Range> ranges) const
176176 while (ranges.size () != dimentions ())
177177 ranges.push_back (all ());
178178
179- auto resolve_index = [](intmax_t idx, bool from_back, intmax_t size) {
180- if (from_back == true )
179+ auto resolve_index = [](intmax_t idx, intmax_t size) -> intmax_t {
180+ if (idx < 0 )
181181 return size-idx;
182- else
183- return idx;
182+ return idx;
184183 };
185184
186- auto resolve_range_size = [resolve_index](Range r, intmax_t size) {
187- return resolve_index (r.end (), r.endFromBack (), size) - resolve_index (r.start (), r.startFromBack (), size);
185+ auto is_index_valid = [](intmax_t idx, intmax_t size) -> bool {
186+ if (idx >= 0 )
187+ return idx < size;
188+ return -idx <= size;
188189 };
189190
190191 Shape result_shape;
191192 svector<intmax_t > offset;
193+ Shape viewed_strides = pimpl_->stride ();
194+ offset.reserve (dimentions ());
192195
193- for (size_t i=0 ;i<dimentions ();i++) {
194- Range r = ranges[i];
195-
196- intmax_t start = resolve_index (r.start (), r.startFromBack (), shape ()[i]);
197- intmax_t size = resolve_range_size (r, shape ()[i]);
198-
199- if (size < 0 )
200- throw EtError (" Negative steps not supported now" );
201- if (start < 0 || (start+size) > shape ()[i])
202- throw EtError (" Indexing from " + std::to_string (start+size-1 ) + " is out of the range of " + std::to_string (shape ()[i]));
196+ assert (viewed_strides.size () == dimentions ());
203197
204- offset.push_back (start);
205- if (size != 1 || result_shape.size () != 0 ) // Ignore heading 1 dimentions
198+ for (size_t i=0 ;i<dimentions ();i++) {
199+ const Range& r = ranges[i];
200+ intmax_t dim_size = shape ()[i];
201+
202+ intmax_t start = r.start ().value_or (0 );
203+ intmax_t stop = r.stop ().value_or (dim_size);
204+ intmax_t step = r.step ().value_or (1 );
205+
206+ // Indexing validations
207+ if (step == 0 )
208+ throw EtError (" Error: Step size is zero in dimension " + std::to_string (i));
209+ if (is_index_valid (start, dim_size) == false )
210+ throw EtError (" Starting index " + std::to_string (start) + " is out of range in dimension " + std::to_string (i));
211+ if (is_index_valid (stop, dim_size+1 ) == false )
212+ throw EtError (" Stopping index " + std::to_string (stop) + " is out of range in dimension " + std::to_string (i));
213+
214+ intmax_t real_start = resolve_index (start, dim_size);
215+ intmax_t real_stop = resolve_index (stop, dim_size);
216+ intmax_t size = (real_stop - real_start - 1 ) / step + 1 ;
217+
218+ if ((real_stop - real_start) * step < 0 )
219+ throw EtError (" Step is going in the wrong direction. Will cause infinate loop" );
220+ viewed_strides[i] *= step;
221+
222+ offset.push_back (real_start);
223+ if (size != 1 || result_shape.empty () == false ) // Ignore heading 1 dimentions
206224 result_shape.push_back (size);
207225 }
208226
209227 // If all dims are 1, thus no shape. Give it a shape
210- if (result_shape.size () == 0 )
228+ if (result_shape.empty () == true )
211229 result_shape.push_back (1 );
212230
213- Shape view_meta_strides = pimpl_->stride ();
214231 size_t initial_offset = unfold (offset, pimpl_->stride ())+pimpl_->offset ();
215- return std::make_shared<TensorImpl>(pimpl_->buffer (), result_shape, view_meta_strides , initial_offset);
232+ return std::make_shared<TensorImpl>(pimpl_->buffer (), result_shape, viewed_strides , initial_offset);
216233}
217234
218235Tensor et::zeros (const Shape& shape, DType dtype, Backend* backend)
@@ -364,7 +381,7 @@ inline Shape brodcast_result_shape(Shape a, Shape b)
364381Tensor et::brodcast_to (const Tensor& t, Shape s)
365382{
366383 et_assert (s.size () >= t.dimentions ());
367- Shape stride = leftpad (shapeToStride (t. shape () ), s.size (), 0 );
384+ Shape stride = leftpad (t. stride ( ), s.size (), 0 );
368385 Shape shape = leftpad (t.shape (), s.size (), 0 );
369386 for (size_t i=0 ;i<s.size ();i++) {
370387 if (shape[i] != s[i])
@@ -386,4 +403,4 @@ std::pair<Tensor, Tensor> et::brodcast_tensors(const Tensor& a, const Tensor& b)
386403std::pair<Tensor, Tensor> Tensor::brodcast (const Tensor& other) const
387404{
388405 return brodcast_tensors (*this , other);
389- }
406+ }
0 commit comments