@@ -103,6 +103,13 @@ TEST_CASE("Testing Tensor", "[Tensor]")
103103 CHECK (r.item <uint8_t >() == true );
104104 }
105105
106+ SECTION (" Create Tensor from vector" ) {
107+ std::vector<int > v = {1 , 2 , 3 , 4 };
108+ Tensor t = Tensor (v);
109+ CHECK (t.size () == intmax_t (v.size ()));
110+ CHECK (t.dtype () == DType::Int);
111+ }
112+
106113 SECTION (" tensor like" ) {
107114 Tensor t = ones ({4 ,4 });
108115 Tensor q = ones_like (t);
@@ -293,6 +300,17 @@ TEST_CASE("Testing Tensor", "[Tensor]")
293300 // Check a subset of weather the result is correct
294301 CHECK (t[{2 , 2 }].item <int >() == 11 );
295302 }
303+
304+ SECTION (" all/any test" ) {
305+ CHECK (ones ({3 }).any () == true );
306+ CHECK (zeros ({3 }).any () == false );
307+ CHECK (ones ({3 }).all () == true );
308+ CHECK (zeros ({3 }).all () == false );
309+ CHECK ((ones ({7 }) == zeros ({7 })).all () == false );
310+ CHECK ((ones ({7 }) == zeros ({7 })).any () == false );
311+ CHECK ((ones ({4 ,4 }) == t).any () == true );
312+ CHECK ((ones ({4 ,4 }) == t).all () == false );
313+ }
296314 }
297315
298316 SECTION (" item" ) {
@@ -311,8 +329,8 @@ TEST_CASE("Testing Tensor", "[Tensor]")
311329 Tensor q = zeros ({3 , 4 });
312330 STATIC_REQUIRE (std::is_same_v<Tensor::iterator::value_type, Tensor>);
313331
314- // Tensor::iterator should be bideractional
315- // Reference: http://www.cplusplus.com/reference/iterator/BidirectionalIterator /
332+ // Tensor::iterator should be ramdp,
333+ // Reference: http://www.cplusplus.com/reference/iterator/RandomAccessIterator /
316334 STATIC_REQUIRE (std::is_default_constructible_v<Tensor::iterator>);
317335 STATIC_REQUIRE (std::is_copy_constructible_v<Tensor::iterator>);
318336 STATIC_REQUIRE (std::is_copy_assignable_v<Tensor::iterator>);
@@ -321,6 +339,8 @@ TEST_CASE("Testing Tensor", "[Tensor]")
321339 CHECK (t.begin () == t.begin ());
322340 CHECK ((*t.begin ()).shape () == Shape{4 });
323341 CHECK (t.begin ()->shape () == Shape{4 });
342+ CHECK (t.end () - t.begin () == t.shape ()[0 ]);
343+ CHECK (t.begin ()[2 ].isSame (*t.back ()) == true );
324344 auto it1 = t.begin (), it2 = t.begin ();
325345 it1++;
326346 ++it2;
@@ -341,6 +361,36 @@ TEST_CASE("Testing Tensor", "[Tensor]")
341361 CHECK (num_iteration == t.shape ()[0 ]);
342362 CHECK (t.sum ().item <int >() == 42 *t.size ());
343363 }
364+
365+ SECTION (" swapping Tensor" ) {
366+ std::vector<int > v1 = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 };
367+ std::vector<int > v2 = {11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 };
368+ Tensor t = Tensor (v1).reshape ({3 , 4 });
369+ Tensor q = Tensor (v2).reshape ({3 , 4 });
370+ Tensor t_old = t.copy ();
371+ Tensor q_old = q.copy ();
372+
373+ SECTION (" swap" ) {
374+ swap (t, q);
375+ CHECK (t.isSame (q_old));
376+ CHECK (q.isSame (t_old));
377+ }
378+
379+ SECTION (" swapping views" ) {
380+ swap (q[{1 }], t[{2 }]);
381+ CHECK (t[{2 }].isSame (q_old[{1 }]));
382+ CHECK (q[{1 }].isSame (t_old[{2 }]));
383+ }
384+
385+ SECTION (" swaping itself" ) {
386+ swap (t[{0 }], t[{0 }]);
387+ REQUIRE (t[{0 }].isSame (t_old[{0 }]));
388+
389+ swap (t[{0 }], t[{1 }]);
390+ REQUIRE (t[{0 }].isSame (t_old[{1 }]));
391+ REQUIRE (t[{1 }].isSame (t_old[{0 }]));
392+ }
393+ }
344394}
345395
346396TEST_CASE (" Testing Encoders" , " [Encoder]" )
@@ -768,6 +818,14 @@ TEST_CASE("Type system")
768818 return true ;
769819 }();
770820
821+ SECTION (" abs" ) {
822+ CHECK (abs (ones ({1 }, DType::Bool)).dtype () == DType::Int32);
823+ CHECK (abs (ones ({1 }, DType::Int32)).dtype () == DType::Int32);
824+ CHECK (abs (ones ({1 }, DType::Float)).dtype () == DType::Float);
825+ if (support_fp16)
826+ CHECK (abs (ones ({1 }, DType::Half)).dtype () == DType::Half);
827+ }
828+
771829 SECTION (" exp" ) {
772830 CHECK (exp (ones ({1 }, DType::Bool)).dtype () == DType::Float);
773831 CHECK (exp (ones ({1 }, DType::Int32)).dtype () == DType::Float);
@@ -876,6 +934,50 @@ TEST_CASE("Type system")
876934 }
877935}
878936
937+ // TODO: Should I count this as an integration test?
938+ // This test checks all components of Tensor works together properly
939+ TEST_CASE (" Complex Tensor operations" )
940+ {
941+ SECTION (" Vector inner product" ) {
942+ std::vector<int > v1 = {1 , 6 , 7 , 9 , 15 , 6 };
943+ std::vector<int > v2 = {3 , 7 , 8 , -1 , 6 , 15 };
944+ REQUIRE (v1.size () == v2.size ());
945+ Tensor a = Tensor (v1);
946+ Tensor b = Tensor (v2);
947+
948+ CHECK ((a*b).sum ().item <int >() == std::inner_product (v1.begin (), v1.end (), v2.begin (), 0 ));
949+ }
950+
951+ SECTION (" shuffle" ) {
952+ std::mt19937 rng;
953+ std::vector<int > v1 = {1 , 8 , 6 , 7
954+ , 3 , 2 , 5 , 6
955+ , 4 , 3 , 2 , 7
956+ , 9 , 0 ,1 , 1 };
957+ Tensor a = Tensor (v1).reshape ({4 ,4 });
958+ std::shuffle (a.begin (), a.end (), rng);
959+ CHECK (std::accumulate (v1.begin (), v1.end (), 0 ) == a.sum ().item <int >());
960+ }
961+
962+ SECTION (" find_if" ) {
963+ std::vector<int > v1 = {1 , 8 , 6 , 7
964+ , 3 , 2 , 5 , 6
965+ , 4 , 3 , 2 , 7
966+ , 9 , 0 ,1 , 1 };
967+ Tensor a = Tensor (v1).reshape ({4 ,4 });
968+ Tensor b = a[{0 }];
969+
970+ CHECK (std::find_if (a.begin (), a.end (), [&b](auto t){ return t.isSame (b); }) != a.end ());
971+ }
972+
973+ SECTION (" transform" ) {
974+ Tensor a = ones ({12 , 6 });
975+ Tensor b = ones ({12 , 6 });
976+ std::transform (a.begin (), a.end (), b.begin (), [](const auto & t){return zeros_like (t);});
977+ CHECK (b.isSame (zeros_like (a)));
978+ }
979+ }
980+
879981// TEST_CASE("Serealize")
880982// {
881983// using namespace et;
0 commit comments