Skip to content

Commit c778f26

Browse files
committed
Add blas_connector CPU tests
1 parent 31d4dff commit c778f26

File tree

1 file changed

+130
-2
lines changed

1 file changed

+130
-2
lines changed

source/module_base/test/blas_connector_test.cpp

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "../blas_connector.h"
2+
#include "../module_device/memory_op.h"
23
#include "gtest/gtest.h"
34

45
#include <algorithm>
@@ -84,8 +85,7 @@ TEST(blas_connector, Scal) {
8485
};
8586
for (int i = 0; i < size; i++)
8687
answer[i] = result[i] * scale;
87-
BlasConnector bs;
88-
bs.scal(size,scale,result,incx);
88+
BlasConnector::scal(size,scale,result,incx);
8989
// incx is the spacing between elements if result
9090
for (int i = 0; i < size; i++) {
9191
EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real());
@@ -313,6 +313,84 @@ TEST(blas_connector, zgemv_) {
313313
}
314314
}
315315

316+
TEST(blas_connector, Gemv) {
317+
typedef std::complex<double> T;
318+
const char transa_m = 'N';
319+
const char transa_n = 'T';
320+
const char transa_h = 'C';
321+
const int size_m = 3;
322+
const int size_n = 4;
323+
const T alpha_const = {2, 3};
324+
const T beta_const = {3, 4};
325+
const int lda = 5;
326+
const int incx = 1;
327+
const int incy = 1;
328+
std::array<T, size_m> x_const_m, x_const_c, result_m, answer_m, c_dot_m{};
329+
std::array<T, size_n> x_const_n, result_n, result_c, answer_n, answer_c,
330+
c_dot_n{}, c_dot_c{};
331+
std::generate(x_const_n.begin(), x_const_n.end(), []() {
332+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
333+
static_cast<double>(std::rand() / double(RAND_MAX))};
334+
});
335+
std::generate(result_n.begin(), result_n.end(), []() {
336+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
337+
static_cast<double>(std::rand() / double(RAND_MAX))};
338+
});
339+
std::generate(x_const_m.begin(), x_const_m.end(), []() {
340+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
341+
static_cast<double>(std::rand() / double(RAND_MAX))};
342+
});
343+
std::generate(result_m.begin(), result_m.end(), []() {
344+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
345+
static_cast<double>(std::rand() / double(RAND_MAX))};
346+
});
347+
std::array<T, size_n * lda> a_const;
348+
std::generate(a_const.begin(), a_const.end(), []() {
349+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
350+
static_cast<double>(std::rand() / double(RAND_MAX))};
351+
});
352+
for (int i = 0; i < size_m; i++) {
353+
for (int j = 0; j < size_n; j++) {
354+
c_dot_m[i] += a_const[i + j * lda] * x_const_n[j];
355+
}
356+
answer_m[i] = alpha_const * c_dot_m[i] + beta_const * result_m[i];
357+
}
358+
BlasConnector::gemv(transa_m, size_m, size_n, alpha_const, a_const.data(), lda,
359+
x_const_n.data(), incx, beta_const, result_m.data(), incy);
360+
361+
for (int j = 0; j < size_n; j++) {
362+
for (int i = 0; i < size_m; i++) {
363+
c_dot_n[j] += a_const[i + j * lda] * x_const_m[i];
364+
}
365+
answer_n[j] = alpha_const * c_dot_n[j] + beta_const * result_n[j];
366+
}
367+
BlasConnector::gemv(transa_n, size_m, size_n, alpha_const, a_const.data(), lda,
368+
x_const_n.data(), incx, beta_const, result_m.data(), incy);
369+
370+
for (int j = 0; j < size_n; j++) {
371+
for (int i = 0; i < size_m; i++) {
372+
c_dot_c[j] += conj(a_const[i + j * lda]) * x_const_c[i];
373+
}
374+
answer_c[j] = alpha_const * c_dot_c[j] + beta_const * result_c[j];
375+
}
376+
BlasConnector::gemv(transa_h, size_m, size_n, alpha_const, a_const.data(), lda,
377+
x_const_n.data(), incx, beta_const, result_m.data(), incy);
378+
379+
for (int i = 0; i < size_m; i++) {
380+
EXPECT_DOUBLE_EQ(answer_m[i].real(), result_m[i].real());
381+
EXPECT_DOUBLE_EQ(answer_m[i].imag(), result_m[i].imag());
382+
}
383+
for (int j = 0; j < size_n; j++) {
384+
EXPECT_DOUBLE_EQ(answer_n[j].real(), result_n[j].real());
385+
EXPECT_DOUBLE_EQ(answer_n[j].imag(), result_n[j].imag());
386+
}
387+
for (int j = 0; j < size_n; j++) {
388+
EXPECT_DOUBLE_EQ(answer_c[j].real(), result_c[j].real());
389+
EXPECT_DOUBLE_EQ(answer_c[j].imag(), result_c[j].imag());
390+
}
391+
}
392+
393+
316394
TEST(blas_connector, dgemm_) {
317395
typedef double T;
318396
const char transa_m = 'N';
@@ -404,6 +482,56 @@ TEST(blas_connector, zgemm_) {
404482
}
405483
}
406484

485+
TEST(blas_connector, Gemm) {
486+
typedef std::complex<double> T;
487+
const char transa_m = 'N';
488+
const char transb_m = 'N';
489+
const int size_m = 3;
490+
const int size_n = 4;
491+
const int size_k = 5;
492+
const T alpha_const = {2, 3};
493+
const T beta_const = {3, 4};
494+
const int lda = 6;
495+
const int ldb = 5;
496+
const int ldc = 4;
497+
std::array<T, size_k * lda> a_const;
498+
std::array<T, size_n * ldb> b_const;
499+
std::array<T, size_n * ldc> c_dot{}, answer, result;
500+
std::generate(a_const.begin(), a_const.end(), []() {
501+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
502+
static_cast<double>(std::rand() / double(RAND_MAX))};
503+
});
504+
std::generate(b_const.begin(), b_const.end(), []() {
505+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
506+
static_cast<double>(std::rand() / double(RAND_MAX))};
507+
});
508+
std::generate(result.begin(), result.end(), []() {
509+
return T{static_cast<double>(std::rand() / double(RAND_MAX)),
510+
static_cast<double>(std::rand() / double(RAND_MAX))};
511+
});
512+
for (int i = 0; i < size_m; i++) {
513+
for (int j = 0; j < size_n; j++) {
514+
for (int k = 0; k < size_k; k++) {
515+
c_dot[i + j * ldc] +=
516+
a_const[i + k * lda] * b_const[k + j * ldb];
517+
}
518+
answer[i + j * ldc] = alpha_const * c_dot[i + j * ldc] +
519+
beta_const * result[i + j * ldc];
520+
}
521+
}
522+
BlasConnector::gemm(transa_m, transb_m, size_m, size_n, size_k, alpha_const,
523+
a_const.data(), lda, b_const.data(), ldb, beta_const,
524+
result.data(), ldc);
525+
526+
for (int i = 0; i < size_m; i++)
527+
for (int j = 0; j < size_n; j++) {
528+
EXPECT_DOUBLE_EQ(answer[i + j * ldc].real(),
529+
result[i + j * ldc].real());
530+
EXPECT_DOUBLE_EQ(answer[i + j * ldc].imag(),
531+
result[i + j * ldc].imag());
532+
}
533+
}
534+
407535
int main(int argc, char **argv) {
408536
testing::InitGoogleTest(&argc, argv);
409537
return RUN_ALL_TESTS();

0 commit comments

Comments
 (0)