Skip to content

Commit cf18fd3

Browse files
authored
Merge pull request #113 from marty1885/master
Improve step support in indexing
2 parents 08c5635 + 789713f commit cf18fd3

File tree

7 files changed

+85
-64
lines changed

7 files changed

+85
-64
lines changed

Etaler/3rdparty/half_precision

Etaler/Core/Tensor.cpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -176,43 +176,60 @@ Tensor Tensor::view(svector<Range> ranges) const
176176
while(ranges.size() != dimentions())
177177
ranges.push_back(all());
178178

179-
auto resolve_index = [](intmax_t idx, bool from_back, intmax_t size) {
180-
if(from_back == true)
179+
auto resolve_index = [](intmax_t idx, intmax_t size) -> intmax_t {
180+
if(idx < 0)
181181
return size-idx;
182-
else
183-
return idx;
182+
return idx;
184183
};
185184

186-
auto resolve_range_size = [resolve_index](Range r, intmax_t size) {
187-
return resolve_index(r.end(), r.endFromBack(), size) - resolve_index(r.start(), r.startFromBack(), size);
185+
auto is_index_valid = [](intmax_t idx, intmax_t size) -> bool {
186+
if(idx >= 0)
187+
return idx < size;
188+
return -idx <= size;
188189
};
189190

190191
Shape result_shape;
191192
svector<intmax_t> offset;
193+
Shape viewed_strides = pimpl_->stride();
194+
offset.reserve(dimentions());
192195

193-
for(size_t i=0;i<dimentions();i++) {
194-
Range r = ranges[i];
195-
196-
intmax_t start = resolve_index(r.start(), r.startFromBack(), shape()[i]);
197-
intmax_t size = resolve_range_size(r, shape()[i]);
198-
199-
if(size < 0)
200-
throw EtError("Negative steps not supported now");
201-
if(start < 0 || (start+size) > shape()[i])
202-
throw EtError("Indexing from " + std::to_string(start+size-1) + " is out of the range of " + std::to_string(shape()[i]));
196+
assert(viewed_strides.size() == dimentions());
203197

204-
offset.push_back(start);
205-
if(size != 1 || result_shape.size() != 0) //Ignore heading 1 dimentions
198+
for(size_t i=0;i<dimentions();i++) {
199+
const Range& r = ranges[i];
200+
intmax_t dim_size = shape()[i];
201+
202+
intmax_t start = r.start().value_or(0);
203+
intmax_t stop = r.stop().value_or(dim_size);
204+
intmax_t step = r.step().value_or(1);
205+
206+
// Indexing validations
207+
if(step == 0)
208+
throw EtError("Error: Step size is zero in dimension " + std::to_string(i));
209+
if(is_index_valid(start, dim_size) == false)
210+
throw EtError("Starting index " + std::to_string(start) + " is out of range in dimension " + std::to_string(i));
211+
if(is_index_valid(stop, dim_size+1) == false)
212+
throw EtError("Stopping index " + std::to_string(stop) + " is out of range in dimension " + std::to_string(i));
213+
214+
intmax_t real_start = resolve_index(start, dim_size);
215+
intmax_t real_stop = resolve_index(stop, dim_size);
216+
intmax_t size = (real_stop - real_start - 1) / step + 1;
217+
218+
if((real_stop - real_start) * step < 0)
219+
throw EtError("Step is going in the wrong direction. Will cause infinate loop");
220+
viewed_strides[i] *= step;
221+
222+
offset.push_back(real_start);
223+
if(size != 1 || result_shape.empty() == false) //Ignore heading 1 dimentions
206224
result_shape.push_back(size);
207225
}
208226

209227
//If all dims are 1, thus no shape. Give it a shape
210-
if(result_shape.size() == 0)
228+
if(result_shape.empty() == true)
211229
result_shape.push_back(1);
212230

213-
Shape view_meta_strides = pimpl_->stride();
214231
size_t initial_offset = unfold(offset, pimpl_->stride())+pimpl_->offset();
215-
return std::make_shared<TensorImpl>(pimpl_->buffer(), result_shape, view_meta_strides, initial_offset);
232+
return std::make_shared<TensorImpl>(pimpl_->buffer(), result_shape, viewed_strides, initial_offset);
216233
}
217234

218235
Tensor et::zeros(const Shape& shape, DType dtype, Backend* backend)
@@ -364,7 +381,7 @@ inline Shape brodcast_result_shape(Shape a, Shape b)
364381
Tensor et::brodcast_to(const Tensor& t, Shape s)
365382
{
366383
et_assert(s.size() >= t.dimentions());
367-
Shape stride = leftpad(shapeToStride(t.shape()), s.size(), 0);
384+
Shape stride = leftpad(t.stride(), s.size(), 0);
368385
Shape shape = leftpad(t.shape(), s.size(), 0);
369386
for(size_t i=0;i<s.size();i++) {
370387
if(shape[i] != s[i])
@@ -386,4 +403,4 @@ std::pair<Tensor, Tensor> et::brodcast_tensors(const Tensor& a, const Tensor& b)
386403
std::pair<Tensor, Tensor> Tensor::brodcast(const Tensor& other) const
387404
{
388405
return brodcast_tensors(*this, other);
389-
}
406+
}

Etaler/Core/Tensor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ struct ETALER_EXPORT Tensor
6363
size_t dimentions() const {return pimpl_->dimentions();}
6464
void resize(Shape s) {pimpl()->resize(s);}
6565
bool iscontiguous() const {return pimpl()->iscontiguous();}
66+
Shape stride() const {return pimpl()->stride();}
6667

67-
Backend* backend() const {return pimpl()->backend();};
68+
Backend* backend() const {return pimpl()->backend();}
6869

6970

7071
template <typename ImplType=TensorImpl>

Etaler/Core/Views.hpp

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <variant>
88
#include <memory>
9+
#include <optional>
910

1011
namespace et
1112
{
@@ -14,52 +15,30 @@ struct Range
1415
{
1516
Range() = default;
1617
Range(intmax_t start)
17-
{
18-
start_ = start;
19-
end_ = start+1;
20-
}
18+
: start_(start), stop_(start+1)
19+
{}
2120

22-
Range(intmax_t start, intmax_t end)
23-
{
24-
start_ = start;
25-
end_ = end;
21+
Range(intmax_t start, intmax_t stop)
22+
: start_(start), stop_(stop)
23+
{}
2624

27-
if (start < 0) {
28-
start_from_back_ = true;
29-
end = -end;
30-
}
25+
Range(intmax_t start, intmax_t stop, intmax_t step)
26+
: start_(start), stop_(stop), step_(step)
27+
{}
3128

32-
if (end < 0) {
33-
end_from_back_ = true;
34-
end = -end;
35-
}
36-
}
37-
38-
Range(intmax_t start, intmax_t end, bool start_from_back, bool end_from_back)
39-
{
40-
start_ = start;
41-
et_assert(end >= 0);
42-
end_ = end;
43-
start_from_back_ = start_from_back;
44-
end_from_back_ = end_from_back;
45-
}
46-
47-
intmax_t start() const {return start_;}
48-
intmax_t end() const {return end_;}
49-
bool startFromBack() const {return start_from_back_;}
50-
bool endFromBack() const {return end_from_back_;}
29+
std::optional<intmax_t> start() const {return start_;}
30+
std::optional<intmax_t> stop() const {return stop_;}
31+
std::optional<intmax_t> step() const {return step_;}
5132

5233
protected:
53-
intmax_t start_ = 0;
54-
bool start_from_back_ = false;
55-
intmax_t end_ = 0;
56-
bool end_from_back_ = false;
57-
//intmax_t step_size_ = 1;
34+
std::optional<intmax_t> start_;
35+
std::optional<intmax_t> stop_;
36+
std::optional<intmax_t> step_;
5837
};
5938

6039
inline Range all()
6140
{
62-
return Range(0, 0, false, true);
41+
return Range();
6342
}
6443

6544
inline Range range(intmax_t start, intmax_t end)
@@ -72,5 +51,10 @@ inline Range range(intmax_t end)
7251
return Range(0, end);
7352
}
7453

54+
inline Range range(intmax_t start, intmax_t stop, intmax_t step)
55+
{
56+
return Range(start, stop, step);
57+
}
58+
7559

7660
}

tests/common_tests.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,25 @@ TEST_CASE("Testing Tensor", "[Tensor]")
239239
CHECK(u.isSame(zeros_like(u)));
240240
}
241241

242+
SECTION("Strided views") {
243+
SECTION("read") {
244+
int a[] = {0, 2};
245+
Tensor q = Tensor({2}, a);
246+
Tensor res = t.view({0, range(0, 3, 2)});
247+
CHECK(res.isSame(q));
248+
}
249+
250+
SECTION("write") {
251+
int a[] = {-1, -1};
252+
Tensor q = Tensor({2}, a);
253+
t[{0, range(0, 3, 2)}] = q;
254+
255+
int b[] = {-1, 1, -1, 3};
256+
Tensor r = Tensor({4}, b);
257+
CHECK(t[{0}].isSame(r));
258+
}
259+
}
260+
242261
SECTION("subscription operator") {
243262
svector<Range> r = {range(2)};
244263
//The [] operator should work exactly like the view() function

0 commit comments

Comments
 (0)