@@ -20,6 +20,22 @@ class BlasTest : public testing::Test {
2020
2121TYPED_TEST_SUITE (BlasTest, base::utils::Types);
2222
23+ TYPED_TEST (BlasTest, Copy) {
24+ using Type = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
25+ using Device = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
26+
27+ blas_copy<Type, Device> copyCalculator;
28+
29+ const int n = 3 ;
30+ const Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
31+ Tensor y = std::move (Tensor ({static_cast <Type>(0.0 ), static_cast <Type>(0.0 ), static_cast <Type>(0.0 )}).to_device <Device>());
32+
33+ copyCalculator (n, x.data <Type>(), 1 , y.data <Type>(), 1 );
34+ const Tensor expected = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
35+
36+ EXPECT_EQ (y, expected);
37+ }
38+
2339TYPED_TEST (BlasTest, Dot) {
2440 using Type = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
2541 using Device = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
@@ -29,7 +45,7 @@ TYPED_TEST(BlasTest, Dot) {
2945 const int n = 3 ;
3046 const Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
3147 const Tensor y = std::move (Tensor ({static_cast <Type>(4.0 ), static_cast <Type>(5.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
32-
48+
3349 Type result = {};
3450 dotCalculator (n, x.data <Type>(), 1 , y.data <Type>(), 1 , &result);
3551 const Type expected = static_cast <Type>(32.0 );
@@ -46,7 +62,7 @@ TYPED_TEST(BlasTest, Scal) {
4662 const int n = 3 ;
4763 const Type alpha = static_cast <Type>(2.0 );
4864 Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
49-
65+
5066 scalCalculator (n, &alpha, x.data <Type>(), 1 );
5167 const Tensor expected = std::move (Tensor ({static_cast <Type>(2.0 ), static_cast <Type>(4.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
5268
@@ -64,7 +80,7 @@ TYPED_TEST(BlasTest, Axpy) {
6480 const Type alpha = static_cast <Type>(2.0 );
6581 const Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
6682 Tensor y = std::move (Tensor ({static_cast <Type>(4.0 ), static_cast <Type>(5.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
67-
83+
6884 axpyCalculator (n, &alpha, x.data <Type>(), 1 , y.data <Type>(), 1 );
6985 const Tensor expected = std::move (Tensor ({static_cast <Type>(6.0 ), static_cast <Type>(9.0 ), static_cast <Type>(12.0 )}).to_device <Device>());
7086
@@ -83,11 +99,11 @@ TYPED_TEST(BlasTest, Gemv) {
8399 const int n = 2 ;
84100 const Type alpha = static_cast <Type>(2.0 );
85101 const Type beta = static_cast <Type>(3.0 );
86- const Tensor A = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 ),
102+ const Tensor A = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 ),
87103 static_cast <Type>(4.0 ), static_cast <Type>(5.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
88104 const Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 )}).to_device <Device>());
89105 Tensor y = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
90-
106+
91107 gemvCalculator (trans, m, n, &alpha, A.data <Type>(), m, x.data <Type>(), 1 , &beta, y.data <Type>(), 1 );
92108 const Tensor expected = std::move (Tensor ({static_cast <Type>(21.0 ), static_cast <Type>(30.0 ), static_cast <Type>(39.0 )}).to_device <Device>());
93109
@@ -114,14 +130,14 @@ TYPED_TEST(BlasTest, GemvBatched) {
114130 std::vector<Type*> y = {};
115131
116132 const Tensor _A = std::move (Tensor ({
117- static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
118- static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
133+ static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
134+ static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
119135 static_cast <Type>(5.0 ), static_cast <Type>(6.0 ),
120-
136+
121137 static_cast <Type>(7.0 ), static_cast <Type>(8.0 ),
122138 static_cast <Type>(9.0 ), static_cast <Type>(10.0 ),
123139 static_cast <Type>(11.0 ),static_cast <Type>(12.0 )}).to_device <Device>());
124-
140+
125141 A.push_back (_A.data <Type>());
126142 A.push_back (_A.data <Type>() + m * n);
127143
@@ -164,14 +180,14 @@ TYPED_TEST(BlasTest, GemvBatchedStrided) {
164180 std::vector<Type*> y = {};
165181
166182 const Tensor _A = std::move (Tensor ({
167- static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
168- static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
183+ static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
184+ static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
169185 static_cast <Type>(5.0 ), static_cast <Type>(6.0 ),
170-
186+
171187 static_cast <Type>(7.0 ), static_cast <Type>(8.0 ),
172188 static_cast <Type>(9.0 ), static_cast <Type>(10.0 ),
173189 static_cast <Type>(11.0 ),static_cast <Type>(12.0 )}).to_device <Device>());
174-
190+
175191 A.push_back (_A.data <Type>());
176192 A.push_back (_A.data <Type>() + m * n);
177193
@@ -205,11 +221,11 @@ TYPED_TEST(BlasTest, Gemm) {
205221 const int n = 2 ;
206222 const Type alpha = static_cast <Type>(2.0 );
207223 const Type beta = static_cast <Type>(3.0 );
208- const Tensor A = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 ),
224+ const Tensor A = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 ),
209225 static_cast <Type>(4.0 ), static_cast <Type>(5.0 ), static_cast <Type>(6.0 )}).to_device <Device>());
210226 const Tensor x = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 )}).to_device <Device>());
211227 Tensor y = std::move (Tensor ({static_cast <Type>(1.0 ), static_cast <Type>(2.0 ), static_cast <Type>(3.0 )}).to_device <Device>());
212-
228+
213229 gemmCalculator (trans, trans, m, 1 , n, &alpha, A.data <Type>(), m, x.data <Type>(), n, &beta, y.data <Type>(), m);
214230 const Tensor expected = std::move (Tensor ({static_cast <Type>(21.0 ), static_cast <Type>(30.0 ), static_cast <Type>(39.0 )}).to_device <Device>());
215231
@@ -237,14 +253,14 @@ TYPED_TEST(BlasTest, GemmBatched) {
237253 std::vector<Type*> y2 = {};
238254
239255 const Tensor _A = std::move (Tensor ({
240- static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
241- static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
256+ static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
257+ static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
242258 static_cast <Type>(5.0 ), static_cast <Type>(6.0 ),
243-
259+
244260 static_cast <Type>(7.0 ), static_cast <Type>(8.0 ),
245261 static_cast <Type>(9.0 ), static_cast <Type>(10.0 ),
246262 static_cast <Type>(11.0 ),static_cast <Type>(12.0 )}).to_device <Device>());
247-
263+
248264 A.push_back (_A.data <Type>());
249265 A.push_back (_A.data <Type>() + m * n);
250266
@@ -287,14 +303,14 @@ TYPED_TEST(BlasTest, GemmBatchedStrided) {
287303 std::vector<Type*> y2 = {};
288304
289305 const Tensor _A = std::move (Tensor ({
290- static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
291- static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
306+ static_cast <Type>(1.0 ), static_cast <Type>(2.0 ),
307+ static_cast <Type>(3.0 ), static_cast <Type>(4.0 ),
292308 static_cast <Type>(5.0 ), static_cast <Type>(6.0 ),
293-
309+
294310 static_cast <Type>(7.0 ), static_cast <Type>(8.0 ),
295311 static_cast <Type>(9.0 ), static_cast <Type>(10.0 ),
296312 static_cast <Type>(11.0 ),static_cast <Type>(12.0 )}).to_device <Device>());
297-
313+
298314 A.push_back (_A.data <Type>());
299315 A.push_back (_A.data <Type>() + m * n);
300316
0 commit comments