Skip to content

Commit 17a4a03

Browse files
authored
Merge pull request #127 from marty1885/apichange
Rework indexing system
2 parents 8d09a41 + 21f9c96 commit 17a4a03

File tree

6 files changed

+66
-33
lines changed

6 files changed

+66
-33
lines changed

Etaler/Algorithms/SpatialPoolerND.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ SpatialPoolerND::SpatialPoolerND(const Shape& input_shape, size_t kernel_size, s
3535
permanences_ = Tensor(output_shape+potential_pool_size, DType::Float, b);
3636

3737
for(size_t i=0;i<(size_t)output_shape.volume();i++) {
38-
svector<Range> write_loc;
38+
IndexList write_loc;
3939
Shape loc = foldIndex(i, output_shape);
4040
for(size_t j=0;j<output_shape.size();j++)
4141
write_loc.push_back(loc[j]);
4242

43-
svector<Range> read_loc(loc.size());
43+
IndexList read_loc(loc.size());
4444
for(size_t j=0;j<loc.size();j++) {
4545
intmax_t pos = loc[j]*stride;
4646
read_loc[j] = range(pos, pos+kernel_size);

Etaler/Backends/CPUBackend.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,8 @@ void CPUBackend::assign(TensorImpl* dest, const TensorImpl* src)
594594
throw EtError("Shape mismatch in tensor assignment. Shape "
595595
+ to_string(dest->shape()) + " and " + to_string(src->shape()));
596596

597-
auto source = realize(src);
598-
599-
if(dest->dtype() != source->dtype())
600-
source = cast(source.get(), dest->dtype());
597+
if(dest->dtype() != src->dtype())
598+
assign(dest, cast(src, dest->dtype()).get());
601599

602600
dispatch(dest->dtype(), [&](auto v) {
603601
using T = decltype(v);

Etaler/Core/Tensor.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,9 @@ bool Tensor::isSame(const Tensor& other) const
168168
return (*this == other).sum().item<int32_t>() == (int32_t)size();
169169
}
170170

171-
Tensor Tensor::view(svector<Range> ranges) const
171+
Tensor Tensor::view(const svector<std::variant<Range, intmax_t, int, size_t, unsigned int>>& rgs) const
172172
{
173+
auto ranges = rgs;
173174
if(ranges.size() > dimentions())
174175
throw EtError("Cannot view a tensor of " + std::to_string(dimentions()) + " with " + std::to_string(ranges.size()) + " dimentions");
175176

@@ -196,16 +197,20 @@ Tensor Tensor::view(svector<Range> ranges) const
196197

197198
assert(viewed_strides.size() == dimentions());
198199

199-
for(size_t i=0;i<dimentions();i++) {
200-
const Range& r = ranges[i];
200+
for(size_t i=0;i<dimentions();i++) { std::visit([&](auto index_range) { // <- make the code neater
201+
const auto& r = index_range;
201202
intmax_t dim_size = shape()[i];
202203

203-
intmax_t start = r.start().value_or(0);
204-
intmax_t stop = r.stop().value_or(dim_size);
204+
// Try to resolve the indexing details
205+
auto [start, stop, step, keep_dim] = [&r, dim_size]() -> std::tuple<intmax_t, intmax_t, intmax_t, bool> {
206+
if constexpr(std::is_same_v<std::decay_t<decltype(r)>, Range> == true)
207+
return {r.start().value_or(0), r.stop().value_or(dim_size), r.step().value_or(1), true};
208+
else // is an integer
209+
return {r, r+1, (r<0?-1:1), false};
210+
}();
205211

206212
intmax_t real_start = resolve_index(start, dim_size);
207213
intmax_t real_stop = resolve_index(stop, dim_size);
208-
intmax_t step = r.step().value_or(real_stop>real_start?1:-1);
209214
intmax_t size = (std::abs(real_stop - real_start) - 1) / std::abs(step) + 1;
210215

211216
// Indexing validations
@@ -221,11 +226,11 @@ Tensor Tensor::view(svector<Range> ranges) const
221226
viewed_strides[i] *= step;
222227

223228
offset.push_back(real_start);
224-
if(size != 1 || result_shape.empty() == false) { //Ignore heading 1 dimentions
229+
if(keep_dim) {
225230
result_shape.push_back(size);
226231
result_stride.push_back(viewed_strides[i]);
227232
}
228-
}
233+
}, ranges[i]); }
229234

230235
//If all dims are 1, thus no shape. Give it a shape
231236
if(result_shape.empty() == true) {
@@ -337,7 +342,7 @@ Tensor et::cat(const svector<Tensor>& tensors, intmax_t dim)
337342
Tensor res = Tensor(res_shape, base_dtype, base_backend);
338343

339344
intmax_t pos = 0;
340-
svector<Range> ranges;
345+
IndexList ranges;
341346
for(size_t i=0;i<res_shape.size();i++)
342347
ranges.push_back(et::all());
343348

Etaler/Core/Tensor.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ struct ETALER_EXPORT TensorIterator
4545
difference_type operator- (const ThisIterator& rhs) const { return offset_ - rhs.offset_; }
4646
ThisIterator operator+(intmax_t n) {return ThisIterator(t_,offset_+n);}
4747
ThisIterator operator-(intmax_t n) {return ThisIterator(t_,offset_-n);}
48+
ThisIterator& operator+=(intmax_t n) { offset_+=n; return *this; }
49+
ThisIterator& operator-=(intmax_t n) { offset_-=n; return *this; }
4850
value_type operator[](intmax_t n) { return *operator+(n); }
4951
bool operator< (const ThisIterator& rhs) const { return offset_ < rhs.offset_; }
5052
bool operator> (const ThisIterator& rhs) const { return offset_ > rhs.offset_; }
@@ -60,6 +62,8 @@ Tensor ETALER_EXPORT brodcast_to(const Tensor& t, Shape s);
6062
ETALER_EXPORT std::ostream& operator<< (std::ostream& os, const Tensor& t);
6163
std::string to_string(const Tensor& t);
6264

65+
using IndexList = svector<std::variant<Range, intmax_t, int, size_t, unsigned int>>;
66+
6367
struct ETALER_EXPORT Tensor
6468
{
6569
Tensor() = default;
@@ -142,7 +146,7 @@ struct ETALER_EXPORT Tensor
142146
Tensor copy() const;
143147

144148
//View/Indexing
145-
Tensor view(svector<Range> ranges) const;
149+
Tensor view(const IndexList& ranges) const;
146150

147151
Tensor reshape(Shape shape) const
148152
{
@@ -220,6 +224,12 @@ struct ETALER_EXPORT Tensor
220224
inline bool any() const { return cast(DType::Bool).sum(std::nullopt, DType::Bool).item<uint8_t>(); }
221225
inline bool all() const { return cast(DType::Bool).sum(std::nullopt).item<int32_t>() == int32_t(size()); }
222226

227+
// Numeric operations
228+
Tensor operator+= (const Tensor& other) { *this = *this + other; return *this; }
229+
Tensor operator-= (const Tensor& other) { *this = *this - other; return *this; }
230+
Tensor operator*= (const Tensor& other) { *this = *this * other; return *this; }
231+
Tensor operator/= (const Tensor& other) { *this = *this / other; return *this; }
232+
223233
Tensor operator- () const {return negate();}
224234
Tensor operator+ () const {return *this;}
225235
Tensor operator! () const {return logical_not();}
@@ -241,7 +251,7 @@ struct ETALER_EXPORT Tensor
241251
Tensor operator!= (const Tensor& other) const {return !equal(other);}
242252

243253
//Subscription operator
244-
Tensor operator [] (svector<Range> r) { return view(r); }
254+
Tensor operator [] (const IndexList& r) { return view(r); }
245255

246256
Tensor sum(std::optional<intmax_t> dim=std::nullopt, DType dtype=DType::Unknown) const;
247257
Tensor abs() const { return backend()->abs(pimpl()); }

Etaler/Core/Views.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ namespace et
1414
struct Range
1515
{
1616
Range() = default;
17-
Range(intmax_t start)
18-
: start_(start), stop_(start+(start>=0?1:-1))
19-
{}
20-
2117
Range(std::optional<intmax_t> start, intmax_t stop
2218
, std::optional<intmax_t> step = std::nullopt)
2319
: start_(start), stop_(stop), step_(step)

tests/common_tests.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <Etaler/Algorithms/SDRClassifer.hpp>
1010

1111
#include <numeric>
12+
#include <execution>
1213

1314
using namespace et;
1415

@@ -219,6 +220,14 @@ TEST_CASE("Testing Tensor", "[Tensor]")
219220
CHECK(realize(r).isSame(pred));
220221
}
221222

223+
SECTION("Indexing a column") {
224+
Tensor q = t.view({all(), 0});
225+
int arr[] = {0, 4, 8, 12};
226+
Tensor r = Tensor({4}, arr);
227+
CHECK(q.shape() == Shape({4}));
228+
CHECK(r.isSame(q));
229+
}
230+
222231
SECTION("Indexing with negative values") {
223232
Tensor q = t.view({3});
224233
Tensor r;
@@ -285,7 +294,7 @@ TEST_CASE("Testing Tensor", "[Tensor]")
285294
}
286295

287296
SECTION("subscription operator") {
288-
svector<Range> r = {range(2)};
297+
IndexList r = {range(2)};
289298
//The [] operator should work exactly like the view() function
290299
CHECK(t[r].isSame(t.view(r)));
291300
}
@@ -938,6 +947,12 @@ TEST_CASE("Type system")
938947
// This test checks all components of Tensor works together properly
939948
TEST_CASE("Complex Tensor operations")
940949
{
950+
std::vector<int> v1 = {1, 8, 6, 7
951+
, 3, 2, 5, 6
952+
, 4, 3, 2, 7
953+
, 9, 0 ,1, 1};
954+
Tensor a = Tensor(v1).reshape({4,4});
955+
941956
SECTION("Vector inner product") {
942957
std::vector<int> v1 = {1, 6, 7, 9, 15, 6};
943958
std::vector<int> v2 = {3, 7, 8, -1, 6, 15};
@@ -948,25 +963,23 @@ TEST_CASE("Complex Tensor operations")
948963
CHECK((a*b).sum().item<int>() == std::inner_product(v1.begin(), v1.end(), v2.begin(), 0));
949964
}
950965

966+
SECTION("assign column to row") {
967+
std::vector<int> v2 = {9, 0, 1, 1};
968+
Tensor b = Tensor(v2);
969+
970+
Tensor t = a.copy();
971+
t[{all(), 1}] = a[{3}];
972+
CHECK(t[{all(), 1}].isSame(b));
973+
}
974+
951975
SECTION("shuffle") {
952976
std::mt19937 rng;
953-
std::vector<int> v1 = {1, 8, 6, 7
954-
, 3, 2, 5, 6
955-
, 4, 3, 2, 7
956-
, 9, 0 ,1, 1};
957-
Tensor a = Tensor(v1).reshape({4,4});
958977
std::shuffle(a.begin(), a.end(), rng);
959978
CHECK(std::accumulate(v1.begin(), v1.end(), 0) == a.sum().item<int>());
960979
}
961980

962981
SECTION("find_if") {
963-
std::vector<int> v1 = {1, 8, 6, 7
964-
, 3, 2, 5, 6
965-
, 4, 3, 2, 7
966-
, 9, 0 ,1, 1};
967-
Tensor a = Tensor(v1).reshape({4,4});
968982
Tensor b = a[{0}];
969-
970983
CHECK(std::find_if(a.begin(), a.end(), [&b](auto t){ return t.isSame(b); }) != a.end());
971984
}
972985

@@ -976,6 +989,17 @@ TEST_CASE("Complex Tensor operations")
976989
std::transform(a.begin(), a.end(), b.begin(), [](const auto& t){return zeros_like(t);});
977990
CHECK(b.isSame(zeros_like(a)));
978991
}
992+
993+
SECTION("accumulate") {
994+
// Test summing along the first dimension. Making sure iterator and sum() works
995+
// Though you should always use the sum() function instead of accumulate or reduce
996+
Tensor t = std::accumulate(a.begin(), a.end(), zeros({a.shape()[1]}));
997+
Tensor q = std::reduce(std::execution::par, a.begin(), a.end(), zeros({a.shape()[1]}));
998+
Tensor a_sum = a.sum(0);
999+
CHECK(t.isSame(a_sum));
1000+
CHECK(q.isSame(a_sum));
1001+
CHECK(t.isSame(q)); // Should be commutative
1002+
}
9791003
}
9801004

9811005
// TEST_CASE("Serealize")

0 commit comments

Comments
 (0)