@@ -367,7 +367,49 @@ TYPED_TEST(PgemmTest, odd_case)
367367 this ->compare_result (ncolA_global, ncolB_global, LDC_global);
368368}
369369
370- TYPED_TEST (PgemmTest, odd_case_not_gather)
370+ TYPED_TEST (PgemmTest, row_parallel)
371+ {
372+ const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
373+ const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
374+
375+ this ->decide_ngroup (1 , 4 );
376+ this ->prepare (ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
377+
378+ this ->pgemm .set_dimension (this ->col_world ,
379+ this ->row_world ,
380+ this ->ncolA ,
381+ this ->LDA ,
382+ this ->ncolB ,
383+ this ->LDB ,
384+ this ->nrow ,
385+ LDC_global);
386+ this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ());
387+
388+ this ->compare_result (ncolA_global, ncolB_global, LDC_global);
389+ }
390+
391+ TYPED_TEST (PgemmTest, col_parallel)
392+ {
393+ const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
394+ const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
395+
396+ this ->decide_ngroup (4 , 1 );
397+ this ->prepare (ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
398+
399+ this ->pgemm .set_dimension (this ->col_world ,
400+ this ->row_world ,
401+ this ->ncolA ,
402+ this ->LDA ,
403+ this ->ncolB ,
404+ this ->LDB ,
405+ this ->nrow ,
406+ LDC_global);
407+ this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ());
408+
409+ this ->compare_result (ncolA_global, ncolB_global, LDC_global);
410+ }
411+
412+ TYPED_TEST (PgemmTest, divide_col)
371413{
372414 const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
373415 const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
@@ -392,7 +434,7 @@ TYPED_TEST(PgemmTest, odd_case_not_gather)
392434 this ->LDB ,
393435 this ->nrow ,
394436 LDC_global,
395- false );
437+ 2 );
396438 this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ()+ start);
397439
398440
@@ -408,34 +450,32 @@ TYPED_TEST(PgemmTest, odd_case_not_gather)
408450 }
409451}
410452
411- TYPED_TEST (PgemmTest, row_parallel )
453+ TYPED_TEST (PgemmTest, divide_row )
412454{
413455 const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
414456 const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
415457
416- this ->decide_ngroup (1 , 4 );
458+ this ->decide_ngroup (2 , 2 );
417459 this ->prepare (ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
460+ std::vector<int > colA_loc (this ->nproc_col );
461+ MPI_Allgather (&this ->ncolA , 1 , MPI_INT, colA_loc.data (), 1 , MPI_INT, this ->col_world );
462+ std::vector<int > displs (this ->nproc_col );
463+ displs[0 ] = 0 ;
464+ for (int i = 1 ; i < this ->nproc_col ; i++)
465+ {
466+ displs[i] = (displs[i - 1 ] + colA_loc[i - 1 ]);
467+ }
468+ int start = displs[this ->rank_col ];
418469
419- this ->pgemm .set_dimension (this ->col_world ,
420- this ->row_world ,
421- this ->ncolA ,
422- this ->LDA ,
423- this ->ncolB ,
424- this ->LDB ,
425- this ->nrow ,
426- LDC_global);
427- this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ());
428-
429- this ->compare_result (ncolA_global, ncolB_global, LDC_global);
430- }
431-
432- TYPED_TEST (PgemmTest, col_parallel)
433- {
434- const int ncolA_global = 17 , ncolB_global = 7 , nrow_global = 13 ;
435- const int LDA_global = 17 , LDB_global = 18 , LDC_global = 19 ;
436-
437- this ->decide_ngroup (4 , 1 );
438- this ->prepare (ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global);
470+ int LDC_local = this ->ncolA + 2 ;
471+ std::vector<TypeParam> C_loc (LDC_local * ncolB_global, 0.0 );
472+ for (int i = 0 ; i < ncolB_global; i++)
473+ {
474+ for (int j = 0 ; j < this ->ncolA ; j++)
475+ {
476+ C_loc[i * LDC_local + j] = this ->C_global [i * LDC_global + start + j];
477+ }
478+ }
439479
440480 this ->pgemm .set_dimension (this ->col_world ,
441481 this ->row_world ,
@@ -444,10 +484,21 @@ TYPED_TEST(PgemmTest, col_parallel)
444484 this ->ncolB ,
445485 this ->LDB ,
446486 this ->nrow ,
447- LDC_global);
448- this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , this ->C_global .data ());
487+ LDC_local,
488+ 3 );
489+ this ->pgemm .multiply (this ->alpha , this ->A_local .data (), this ->B_local .data (), this ->beta , C_loc.data ());
449490
450- this ->compare_result (ncolA_global, ncolB_global, LDC_global);
491+
492+
493+ for (int i = 0 ; i < ncolB_global; i++)
494+ {
495+ for (int j = 0 ; j < this ->ncolA ; j++)
496+ {
497+ EXPECT_NEAR (get_double (this ->Cref_global [i * LDC_global + start + j]),
498+ get_double (C_loc[i * LDC_local + j]),
499+ 1e-10 );
500+ }
501+ }
451502}
452503
453504int main (int argc, char ** argv)
0 commit comments