@@ -367,6 +367,47 @@ TYPED_TEST(PgemmTest, odd_case)
367367 this ->compare_result (ncolA_global, ncolB_global, LDC_global);
368368}
369369
370+ TYPED_TEST (PgemmTest, odd_case_not_gather)
371+ {
372+ const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
373+ const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
374+
375+ this ->decide_ngroup (2 , 2 );
376+ this ->prepare (ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
377+ std::vector<int > colB_loc (this ->nproc_col );
378+ MPI_Allgather (&this ->ncolB , 1 , MPI_INT, colB_loc.data (), 1 , MPI_INT, this ->col_world );
379+ std::vector<int > displs (this ->nproc_col );
380+ displs[0 ] = 0 ;
381+ for (int i = 1 ; i < this ->nproc_col ; i++)
382+ {
383+ displs[i] = (displs[i - 1 ] + colB_loc[i - 1 ]) * LDC_global;
384+ }
385+ int start = displs[this ->rank_col ];
386+
387+ this ->pgemm .set_dimension (this ->col_world ,
388+ this ->row_world ,
389+ this ->ncolA ,
390+ this ->LDA ,
391+ this ->ncolB ,
392+ this ->LDB ,
393+ this ->nrow ,
394+ LDC_global,
395+ false );
396+ this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ()+ start);
397+
398+
399+
400+ for (int i = 0 ; i < this ->ncolB ; i++)
401+ {
402+ for (int j = 0 ; j < ncolA_global; j++)
403+ {
404+ EXPECT_NEAR (get_double (this ->Cref_global [i * LDC_global + start + j]),
405+ get_double (this ->C_global [i * LDC_global + start + j]),
406+ 1e-10 );
407+ }
408+ }
409+ }
410+
370411TYPED_TEST (PgemmTest, row_parallel)
371412{
372413 const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
0 commit comments