Skip to content

Commit 773aa55

Browse files
committed
add test
1 parent bc7b92c commit 773aa55

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

source/module_base/para_gemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void PGemmCN<T, Device>::set_dimension(
4646
this->nrow = nrow_in;
4747
#ifdef __MPI
4848
this->gatherC = gatherC_in;
49+
requests.resize(col_nproc);
4950
colA_loc.resize(col_nproc);
5051
MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world);
5152
for (int ip = 0; ip < col_nproc; ip++)
@@ -58,7 +59,6 @@ void PGemmCN<T, Device>::set_dimension(
5859
colB_loc.resize(col_nproc);
5960
recv_counts.resize(col_nproc);
6061
displs.resize(col_nproc);
61-
requests.resize(col_nproc);
6262
MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world);
6363
for (int ip = 0; ip < col_nproc; ip++)
6464
{

source/module_base/test_parallel/test_para_gemm.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,47 @@ TYPED_TEST(PgemmTest, odd_case)
367367
this->compare_result(ncolA_global, ncolB_global, LDC_global);
368368
}
369369

370+
TYPED_TEST(PgemmTest, odd_case_not_gather)
371+
{
372+
const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;
373+
const int LDA_global = 17, LDB_global = 18, LDC_global = 19;
374+
375+
this->decide_ngroup(2, 2);
376+
this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
377+
std::vector<int> colB_loc(this->nproc_col);
378+
MPI_Allgather(&this->ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, this->col_world);
379+
std::vector<int> displs(this->nproc_col);
380+
displs[0] = 0;
381+
for (int i = 1; i < this->nproc_col; i++)
382+
{
383+
displs[i] = (displs[i - 1] + colB_loc[i - 1]) * LDC_global;
384+
}
385+
int start = displs[this->rank_col];
386+
387+
this->pgemm.set_dimension(this->col_world,
388+
this->row_world,
389+
this->ncolA,
390+
this->LDA,
391+
this->ncolB,
392+
this->LDB,
393+
this->nrow,
394+
LDC_global,
395+
false);
396+
this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start);
397+
398+
399+
400+
for (int i = 0; i < this->ncolB; i++)
401+
{
402+
for (int j = 0; j < ncolA_global; j++)
403+
{
404+
EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + start + j]),
405+
get_double(this->C_global[i * LDC_global + start + j]),
406+
1e-10);
407+
}
408+
}
409+
}
410+
370411
TYPED_TEST(PgemmTest, row_parallel)
371412
{
372413
const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13;

0 commit comments

Comments
 (0)