|
1 | 1 | #include "../blas_connector.h" |
| 2 | +#include "../module_device/memory_op.h" |
2 | 3 | #include "gtest/gtest.h" |
3 | 4 |
|
4 | 5 | #include <algorithm> |
@@ -84,8 +85,7 @@ TEST(blas_connector, Scal) { |
84 | 85 | }; |
85 | 86 | for (int i = 0; i < size; i++) |
86 | 87 | answer[i] = result[i] * scale; |
87 | | - BlasConnector bs; |
88 | | - bs.scal(size,scale,result,incx); |
| 88 | + BlasConnector::scal(size,scale,result,incx); |
89 | 89 | // incx is the spacing between elements if result |
90 | 90 | for (int i = 0; i < size; i++) { |
91 | 91 | EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); |
@@ -313,6 +313,84 @@ TEST(blas_connector, zgemv_) { |
313 | 313 | } |
314 | 314 | } |
315 | 315 |
|
| 316 | +TEST(blas_connector, Gemv) { |
| 317 | + typedef std::complex<double> T; |
| 318 | + const char transa_m = 'N'; |
| 319 | + const char transa_n = 'T'; |
| 320 | + const char transa_h = 'C'; |
| 321 | + const int size_m = 3; |
| 322 | + const int size_n = 4; |
| 323 | + const T alpha_const = {2, 3}; |
| 324 | + const T beta_const = {3, 4}; |
| 325 | + const int lda = 5; |
| 326 | + const int incx = 1; |
| 327 | + const int incy = 1; |
| 328 | + std::array<T, size_m> x_const_m, x_const_c, result_m, answer_m, c_dot_m{}; |
| 329 | + std::array<T, size_n> x_const_n, result_n, result_c, answer_n, answer_c, |
| 330 | + c_dot_n{}, c_dot_c{}; |
| 331 | + std::generate(x_const_n.begin(), x_const_n.end(), []() { |
| 332 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 333 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 334 | + }); |
| 335 | + std::generate(result_n.begin(), result_n.end(), []() { |
| 336 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 337 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 338 | + }); |
| 339 | + std::generate(x_const_m.begin(), x_const_m.end(), []() { |
| 340 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 341 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 342 | + }); |
| 343 | + std::generate(result_m.begin(), result_m.end(), []() { |
| 344 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 345 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 346 | + }); |
| 347 | + std::array<T, size_n * lda> a_const; |
| 348 | + std::generate(a_const.begin(), a_const.end(), []() { |
| 349 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 350 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 351 | + }); |
| 352 | + for (int i = 0; i < size_m; i++) { |
| 353 | + for (int j = 0; j < size_n; j++) { |
| 354 | + c_dot_m[i] += a_const[i + j * lda] * x_const_n[j]; |
| 355 | + } |
| 356 | + answer_m[i] = alpha_const * c_dot_m[i] + beta_const * result_m[i]; |
| 357 | + } |
| 358 | + BlasConnector::gemv(transa_m, size_m, size_n, alpha_const, a_const.data(), lda, |
| 359 | + x_const_n.data(), incx, beta_const, result_m.data(), incy); |
| 360 | + |
| 361 | + for (int j = 0; j < size_n; j++) { |
| 362 | + for (int i = 0; i < size_m; i++) { |
| 363 | + c_dot_n[j] += a_const[i + j * lda] * x_const_m[i]; |
| 364 | + } |
| 365 | + answer_n[j] = alpha_const * c_dot_n[j] + beta_const * result_n[j]; |
| 366 | + } |
| 367 | + BlasConnector::gemv(transa_n, size_m, size_n, alpha_const, a_const.data(), lda, |
| 368 | + x_const_n.data(), incx, beta_const, result_m.data(), incy); |
| 369 | + |
| 370 | + for (int j = 0; j < size_n; j++) { |
| 371 | + for (int i = 0; i < size_m; i++) { |
| 372 | + c_dot_c[j] += conj(a_const[i + j * lda]) * x_const_c[i]; |
| 373 | + } |
| 374 | + answer_c[j] = alpha_const * c_dot_c[j] + beta_const * result_c[j]; |
| 375 | + } |
| 376 | + BlasConnector::gemv(transa_h, size_m, size_n, alpha_const, a_const.data(), lda, |
| 377 | + x_const_n.data(), incx, beta_const, result_m.data(), incy); |
| 378 | + |
| 379 | + for (int i = 0; i < size_m; i++) { |
| 380 | + EXPECT_DOUBLE_EQ(answer_m[i].real(), result_m[i].real()); |
| 381 | + EXPECT_DOUBLE_EQ(answer_m[i].imag(), result_m[i].imag()); |
| 382 | + } |
| 383 | + for (int j = 0; j < size_n; j++) { |
| 384 | + EXPECT_DOUBLE_EQ(answer_n[j].real(), result_n[j].real()); |
| 385 | + EXPECT_DOUBLE_EQ(answer_n[j].imag(), result_n[j].imag()); |
| 386 | + } |
| 387 | + for (int j = 0; j < size_n; j++) { |
| 388 | + EXPECT_DOUBLE_EQ(answer_c[j].real(), result_c[j].real()); |
| 389 | + EXPECT_DOUBLE_EQ(answer_c[j].imag(), result_c[j].imag()); |
| 390 | + } |
| 391 | +} |
| 392 | + |
| 393 | + |
316 | 394 | TEST(blas_connector, dgemm_) { |
317 | 395 | typedef double T; |
318 | 396 | const char transa_m = 'N'; |
@@ -404,6 +482,56 @@ TEST(blas_connector, zgemm_) { |
404 | 482 | } |
405 | 483 | } |
406 | 484 |
|
| 485 | +TEST(blas_connector, Gemm) { |
| 486 | + typedef std::complex<double> T; |
| 487 | + const char transa_m = 'N'; |
| 488 | + const char transb_m = 'N'; |
| 489 | + const int size_m = 3; |
| 490 | + const int size_n = 4; |
| 491 | + const int size_k = 5; |
| 492 | + const T alpha_const = {2, 3}; |
| 493 | + const T beta_const = {3, 4}; |
| 494 | + const int lda = 6; |
| 495 | + const int ldb = 5; |
| 496 | + const int ldc = 4; |
| 497 | + std::array<T, size_k * lda> a_const; |
| 498 | + std::array<T, size_n * ldb> b_const; |
| 499 | + std::array<T, size_n * ldc> c_dot{}, answer, result; |
| 500 | + std::generate(a_const.begin(), a_const.end(), []() { |
| 501 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 502 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 503 | + }); |
| 504 | + std::generate(b_const.begin(), b_const.end(), []() { |
| 505 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 506 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 507 | + }); |
| 508 | + std::generate(result.begin(), result.end(), []() { |
| 509 | + return T{static_cast<double>(std::rand() / double(RAND_MAX)), |
| 510 | + static_cast<double>(std::rand() / double(RAND_MAX))}; |
| 511 | + }); |
| 512 | + for (int i = 0; i < size_m; i++) { |
| 513 | + for (int j = 0; j < size_n; j++) { |
| 514 | + for (int k = 0; k < size_k; k++) { |
| 515 | + c_dot[i + j * ldc] += |
| 516 | + a_const[i + k * lda] * b_const[k + j * ldb]; |
| 517 | + } |
| 518 | + answer[i + j * ldc] = alpha_const * c_dot[i + j * ldc] + |
| 519 | + beta_const * result[i + j * ldc]; |
| 520 | + } |
| 521 | + } |
| 522 | + BlasConnector::gemm(transa_m, transb_m, size_m, size_n, size_k, alpha_const, |
| 523 | + a_const.data(), lda, b_const.data(), ldb, beta_const, |
| 524 | + result.data(), ldc); |
| 525 | + |
| 526 | + for (int i = 0; i < size_m; i++) |
| 527 | + for (int j = 0; j < size_n; j++) { |
| 528 | + EXPECT_DOUBLE_EQ(answer[i + j * ldc].real(), |
| 529 | + result[i + j * ldc].real()); |
| 530 | + EXPECT_DOUBLE_EQ(answer[i + j * ldc].imag(), |
| 531 | + result[i + j * ldc].imag()); |
| 532 | + } |
| 533 | +} |
| 534 | + |
407 | 535 | int main(int argc, char **argv) { |
408 | 536 | testing::InitGoogleTest(&argc, argv); |
409 | 537 | return RUN_ALL_TESTS(); |
|
0 commit comments