99#include < Etaler/Algorithms/SDRClassifer.hpp>
1010
1111#include < numeric>
12+ #include < execution>
1213
1314using namespace et ;
1415
@@ -219,6 +220,14 @@ TEST_CASE("Testing Tensor", "[Tensor]")
219220 CHECK (realize (r).isSame (pred));
220221 }
221222
223+ SECTION (" Indexing a column" ) {
224+ Tensor q = t.view ({all (), 0 });
225+ int arr[] = {0 , 4 , 8 , 12 };
226+ Tensor r = Tensor ({4 }, arr);
227+ CHECK (q.shape () == Shape ({4 }));
228+ CHECK (r.isSame (q));
229+ }
230+
222231 SECTION (" Indexing with negative values" ) {
223232 Tensor q = t.view ({3 });
224233 Tensor r;
@@ -285,7 +294,7 @@ TEST_CASE("Testing Tensor", "[Tensor]")
285294 }
286295
287296 SECTION (" subscription operator" ) {
288- svector<Range> r = {range (2 )};
297+ IndexList r = {range (2 )};
289298 // The [] operator should work exactly like the view() function
290299 CHECK (t[r].isSame (t.view (r)));
291300 }
@@ -938,6 +947,12 @@ TEST_CASE("Type system")
938947// This test checks all components of Tensor works together properly
939948TEST_CASE (" Complex Tensor operations" )
940949{
950+ std::vector<int > v1 = {1 , 8 , 6 , 7
951+ , 3 , 2 , 5 , 6
952+ , 4 , 3 , 2 , 7
953+ , 9 , 0 ,1 , 1 };
954+ Tensor a = Tensor (v1).reshape ({4 ,4 });
955+
941956 SECTION (" Vector inner product" ) {
942957 std::vector<int > v1 = {1 , 6 , 7 , 9 , 15 , 6 };
943958 std::vector<int > v2 = {3 , 7 , 8 , -1 , 6 , 15 };
@@ -948,25 +963,23 @@ TEST_CASE("Complex Tensor operations")
948963 CHECK ((a*b).sum ().item <int >() == std::inner_product (v1.begin (), v1.end (), v2.begin (), 0 ));
949964 }
950965
966+ SECTION (" assign column to row" ) {
967+ std::vector<int > v2 = {9 , 0 , 1 , 1 };
968+ Tensor b = Tensor (v2);
969+
970+ Tensor t = a.copy ();
971+ t[{all (), 1 }] = a[{3 }];
972+ CHECK (t[{all (), 1 }].isSame (b));
973+ }
974+
951975 SECTION (" shuffle" ) {
952976 std::mt19937 rng;
953- std::vector<int > v1 = {1 , 8 , 6 , 7
954- , 3 , 2 , 5 , 6
955- , 4 , 3 , 2 , 7
956- , 9 , 0 ,1 , 1 };
957- Tensor a = Tensor (v1).reshape ({4 ,4 });
958977 std::shuffle (a.begin (), a.end (), rng);
959978 CHECK (std::accumulate (v1.begin (), v1.end (), 0 ) == a.sum ().item <int >());
960979 }
961980
962981 SECTION (" find_if" ) {
963- std::vector<int > v1 = {1 , 8 , 6 , 7
964- , 3 , 2 , 5 , 6
965- , 4 , 3 , 2 , 7
966- , 9 , 0 ,1 , 1 };
967- Tensor a = Tensor (v1).reshape ({4 ,4 });
968982 Tensor b = a[{0 }];
969-
970983 CHECK (std::find_if (a.begin (), a.end (), [&b](auto t){ return t.isSame (b); }) != a.end ());
971984 }
972985
@@ -976,6 +989,17 @@ TEST_CASE("Complex Tensor operations")
976989 std::transform (a.begin (), a.end (), b.begin (), [](const auto & t){return zeros_like (t);});
977990 CHECK (b.isSame (zeros_like (a)));
978991 }
992+
993+ SECTION (" accumulate" ) {
994+ // Test summing along the first dimension. Making sure iterator and sum() works
995+ // Tho you should always use the sum() function instead of accumulate or reduce
996+ Tensor t = std::accumulate (a.begin (), a.end (), zeros ({a.shape ()[1 ]}));
997+ Tensor q = std::reduce (std::execution::par, a.begin (), a.end (), zeros ({a.shape ()[1 ]}));
998+ Tensor a_sum = a.sum (0 );
999+ CHECK (t.isSame (a_sum));
1000+ CHECK (q.isSame (a_sum));
1001+ CHECK (t.isSame (q)); // Should be communicative
1002+ }
9791003}
9801004
9811005// TEST_CASE("Serealize")
0 commit comments