@@ -85,6 +85,10 @@ int MPIR_Group_create(int nproc, MPIR_Group ** new_group_ptr)
8585 return mpi_errno ;
8686}
8787
88+ static bool check_map_is_strided (int size , MPIR_Lpid * map ,
89+ MPIR_Lpid * offset_out , MPIR_Lpid * stride_out ,
90+ MPIR_Lpid * blocksize_out );
91+
8892int MPIR_Group_create_map (int size , int rank , MPIR_Session * session_ptr , MPIR_Lpid * map ,
8993 MPIR_Group * * new_group_ptr )
9094{
@@ -104,10 +108,16 @@ int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_L
104108 newgrp -> rank = rank ;
105109 MPIR_Group_set_session_ptr (newgrp , session_ptr );
106110
107- newgrp -> pmap .use_map = true;
108- newgrp -> pmap .u .map = map ;
111+ if (check_map_is_strided (size , map , & newgrp -> pmap .u .stride .offset ,
112+ & newgrp -> pmap .u .stride .stride , & newgrp -> pmap .u .stride .blocksize )) {
113+ newgrp -> pmap .use_map = false;
114+ MPL_free (map );
115+ } else {
116+ newgrp -> pmap .use_map = true;
117+ newgrp -> pmap .u .map = map ;
118+ /* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
119+ }
109120
110- /* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
111121 * new_group_ptr = newgrp ;
112122 }
113123
@@ -155,50 +165,8 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr,
155165 goto fn_exit ;
156166}
157167
158- static MPIR_Lpid pmap_rank_to_lpid (struct MPIR_Pmap * pmap , int rank )
159- {
160- if (rank < 0 || rank >= pmap -> size ) {
161- return MPI_UNDEFINED ;
162- }
163-
164- if (pmap -> use_map ) {
165- return pmap -> u .map [rank ];
166- } else {
167- MPIR_Lpid i_blk = rank / pmap -> u .stride .blocksize ;
168- MPIR_Lpid r_blk = rank % pmap -> u .stride .blocksize ;
169- return pmap -> u .stride .offset + i_blk * pmap -> u .stride .stride + r_blk ;
170- }
171- }
172-
173- static int pmap_lpid_to_rank (struct MPIR_Pmap * pmap , MPIR_Lpid lpid )
174- {
175- if (pmap -> use_map ) {
176- /* Use linear search for now.
177- * Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
178- */
179- for (int rank = 0 ; rank < pmap -> size ; rank ++ ) {
180- if (pmap -> u .map [rank ] == lpid ) {
181- return rank ;
182- }
183- }
184- return MPI_UNDEFINED ;
185- } else {
186- lpid -= pmap -> u .stride .offset ;
187- MPIR_Lpid i_blk = lpid / pmap -> u .stride .stride ;
188- MPIR_Lpid r_blk = lpid % pmap -> u .stride .stride ;
189-
190- if (r_blk >= pmap -> u .stride .blocksize ) {
191- return MPI_UNDEFINED ;
192- }
193-
194- int rank = i_blk * pmap -> u .stride .blocksize + r_blk ;
195- if (rank >= 0 && rank < pmap -> size ) {
196- return rank ;
197- } else {
198- return MPI_UNDEFINED ;
199- }
200- }
201- }
168+ static int pmap_lpid_to_rank (struct MPIR_Pmap * pmap , MPIR_Lpid lpid );
169+ static MPIR_Lpid pmap_rank_to_lpid (struct MPIR_Pmap * pmap , int rank );
202170
203171int MPIR_Group_lpid_to_rank (MPIR_Group * group , MPIR_Lpid lpid )
204172{
@@ -397,3 +365,112 @@ void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_p
397365 MPIR_Session_add_ref (session_ptr );
398366 }
399367}
368+
369+ /* internal static routines */
370+
371+ static bool check_map_is_strided (int size , MPIR_Lpid * map ,
372+ MPIR_Lpid * offset_out , MPIR_Lpid * stride_out ,
373+ MPIR_Lpid * blocksize_out )
374+ {
375+ MPIR_Assert (size > 0 );
376+ if (size == 1 ) {
377+ * offset_out = map [0 ];
378+ * stride_out = 1 ;
379+ * blocksize_out = 1 ;
380+ return true;
381+ } else {
382+ MPIR_Lpid offset , stride , blocksize ;
383+ offset = map [0 ];
384+
385+ blocksize = 1 ;
386+ for (int i = 1 ; i < size ; i ++ ) {
387+ if (map [i ] - map [i - 1 ] == 1 ) {
388+ blocksize ++ ;
389+ } else {
390+ break ;
391+ }
392+ }
393+ if (blocksize == size ) {
394+ /* consecutive */
395+ * offset_out = offset ;
396+ * stride_out = 1 ;
397+ * blocksize_out = 1 ;
398+ return true;
399+ } else {
400+ /* NOTE: stride may be negative */
401+ stride = map [blocksize ] - map [0 ];
402+ int n_strides = (size + blocksize - 1 ) / blocksize ;
403+ int k = 0 ;
404+ for (int i = 0 ; i < n_strides ; i ++ ) {
405+ for (int j = 0 ; j < blocksize ; j ++ ) {
406+ if (map [k ] != offset + i * stride + j ) {
407+ return false;
408+ }
409+ k ++ ;
410+ if (k == size ) {
411+ break ;
412+ }
413+ }
414+ }
415+ * offset_out = offset ;
416+ * stride_out = stride ;
417+ * blocksize_out = blocksize ;
418+ return true;
419+ }
420+ }
421+ }
422+
423+ static MPIR_Lpid pmap_rank_to_lpid (struct MPIR_Pmap * pmap , int rank )
424+ {
425+ if (rank < 0 || rank >= pmap -> size ) {
426+ return MPI_UNDEFINED ;
427+ }
428+
429+ if (pmap -> use_map ) {
430+ return pmap -> u .map [rank ];
431+ } else {
432+ MPIR_Lpid i_blk = rank / pmap -> u .stride .blocksize ;
433+ MPIR_Lpid r_blk = rank % pmap -> u .stride .blocksize ;
434+ return pmap -> u .stride .offset + i_blk * pmap -> u .stride .stride + r_blk ;
435+ }
436+ }
437+
438+ static int pmap_lpid_to_rank (struct MPIR_Pmap * pmap , MPIR_Lpid lpid )
439+ {
440+ if (pmap -> use_map ) {
441+ /* Use linear search for now.
442+ * Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
443+ */
444+ for (int rank = 0 ; rank < pmap -> size ; rank ++ ) {
445+ if (pmap -> u .map [rank ] == lpid ) {
446+ return rank ;
447+ }
448+ }
449+ return MPI_UNDEFINED ;
450+ } else {
451+ lpid -= pmap -> u .stride .offset ;
452+ MPIR_Lpid i_blk = lpid / pmap -> u .stride .stride ;
453+ MPIR_Lpid r_blk = lpid % pmap -> u .stride .stride ;
454+ /* NOTE: stride could be negative, in which case, make sure r_blk >= 0 */
455+ if (r_blk < 0 ) {
456+ MPIR_Assert (pmap -> u .stride .stride < 0 );
457+ r_blk -= pmap -> u .stride .stride ;
458+ i_blk += 1 ;
459+ }
460+
461+ if (i_blk < 0 ) {
462+ return MPI_UNDEFINED ;
463+ }
464+
465+ if (r_blk >= pmap -> u .stride .blocksize ) {
466+ return MPI_UNDEFINED ;
467+ }
468+
469+ int rank = i_blk * pmap -> u .stride .blocksize + r_blk ;
470+ if (rank >= 0 && rank < pmap -> size ) {
471+ return rank ;
472+ } else {
473+ return MPI_UNDEFINED ;
474+ }
475+ }
476+ }
0 commit comments