Skip to content

Commit d5d81c7

Browse files
committed
group: optimize grouputil
* Add check_map_is_strided to detect a strided pattern and convert a map into a strided pmap.
* Move internal static routines to the bottom of grouputil.c.
1 parent 37bf832 commit d5d81c7

File tree

1 file changed

+124
-47
lines changed

1 file changed

+124
-47
lines changed

src/mpi/group/grouputil.c

Lines changed: 124 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ int MPIR_Group_create(int nproc, MPIR_Group ** new_group_ptr)
8585
return mpi_errno;
8686
}
8787

88+
static bool check_map_is_strided(int size, MPIR_Lpid * map,
89+
MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
90+
MPIR_Lpid * blocksize_out);
91+
8892
int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_Lpid * map,
8993
MPIR_Group ** new_group_ptr)
9094
{
@@ -104,10 +108,16 @@ int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_L
104108
newgrp->rank = rank;
105109
MPIR_Group_set_session_ptr(newgrp, session_ptr);
106110

107-
newgrp->pmap.use_map = true;
108-
newgrp->pmap.u.map = map;
111+
if (check_map_is_strided(size, map, &newgrp->pmap.u.stride.offset,
112+
&newgrp->pmap.u.stride.stride, &newgrp->pmap.u.stride.blocksize)) {
113+
newgrp->pmap.use_map = false;
114+
MPL_free(map);
115+
} else {
116+
newgrp->pmap.use_map = true;
117+
newgrp->pmap.u.map = map;
118+
/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
119+
}
109120

110-
/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
111121
*new_group_ptr = newgrp;
112122
}
113123

@@ -155,50 +165,8 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr,
155165
goto fn_exit;
156166
}
157167

158-
/* Translate a rank of the process map into its lpid.
 *
 * Returns MPI_UNDEFINED when rank is outside [0, pmap->size).
 * When pmap->use_map is set, the lpid is read directly from the table
 * pmap->u.map. Otherwise it is computed from the strided description:
 * rank is split into a block index (rank / blocksize) and a position inside
 * the block (rank % blocksize), and the result is
 *     offset + block_index * stride + position_in_block
 *
 * This is the pre-move copy of the routine (removed side of the diff); the
 * same definition is re-added at the bottom of the file.
 */
static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
159-
{
160-
if (rank < 0 || rank >= pmap->size) {
161-
return MPI_UNDEFINED;
162-
}
163-
164-
if (pmap->use_map) {
165-
return pmap->u.map[rank];
166-
} else {
167-
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
168-
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
169-
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
170-
}
171-
}
172-
173-
/* Translate an lpid into its rank in the process map.
 *
 * Returns MPI_UNDEFINED when the lpid is not found.
 * When pmap->use_map is set, the table pmap->u.map is scanned linearly and
 * the first matching index is returned. Otherwise the lpid is made relative
 * to offset and split by stride into a block index (quotient) and a position
 * inside the block (remainder); a remainder >= blocksize means the lpid
 * falls in the gap between two blocks.
 *
 * NOTE(review): this pre-move copy (removed side of the diff) has no explicit
 * handling for a negative remainder; it relies on the final range check of
 * rank. Presumably strides were never negative when it was written - confirm
 * before reusing it.
 */
static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
174-
{
175-
if (pmap->use_map) {
176-
/* Use linear search for now.
177-
* Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
178-
*/
179-
for (int rank = 0; rank < pmap->size; rank++) {
180-
if (pmap->u.map[rank] == lpid) {
181-
return rank;
182-
}
183-
}
184-
return MPI_UNDEFINED;
185-
} else {
186-
lpid -= pmap->u.stride.offset;
187-
MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
188-
MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;
189-
190-
if (r_blk >= pmap->u.stride.blocksize) {
191-
return MPI_UNDEFINED;
192-
}
193-
194-
int rank = i_blk * pmap->u.stride.blocksize + r_blk;
195-
if (rank >= 0 && rank < pmap->size) {
196-
return rank;
197-
} else {
198-
return MPI_UNDEFINED;
199-
}
200-
}
201-
}
168+
static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid);
169+
static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank);
202170

203171
int MPIR_Group_lpid_to_rank(MPIR_Group * group, MPIR_Lpid lpid)
204172
{
@@ -397,3 +365,112 @@ void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_p
397365
MPIR_Session_add_ref(session_ptr);
398366
}
399367
}
368+
369+
/* internal static routines */
370+
371+
/* Detect whether the lpid table `map` (with `size` entries, size > 0) follows
 * a regular blocked-stride layout, i.e. whether for every rank k
 *
 *     map[k] == offset + (k / blocksize) * stride + (k % blocksize)
 *
 * where offset is map[0], blocksize is the length of the leading run of
 * consecutive lpids, and stride is the distance between the first lpids of
 * the first two blocks (it may be negative). The last block may be partial.
 *
 * A fully consecutive map (including size == 1) is reported as
 * stride == 1, blocksize == 1.
 *
 * Returns true and fills *offset_out, *stride_out and *blocksize_out when the
 * pattern holds. Returns false otherwise, in which case the output
 * parameters are left untouched.
 */
static bool check_map_is_strided(int size, MPIR_Lpid * map,
                                 MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
                                 MPIR_Lpid * blocksize_out)
{
    MPIR_Assert(size > 0);

    const MPIR_Lpid offset = map[0];

    /* The leading run of consecutive lpids is the candidate block */
    MPIR_Lpid blocksize = 1;
    while (blocksize < size && map[blocksize] - map[blocksize - 1] == 1) {
        blocksize++;
    }

    if (blocksize == size) {
        /* consecutive (this also covers size == 1) */
        *offset_out = offset;
        *stride_out = 1;
        *blocksize_out = 1;
        return true;
    }

    /* NOTE: stride may be negative */
    const MPIR_Lpid stride = map[blocksize] - map[0];
    if (stride == 0) {
        /* The second block starts on the same lpid as the first one, so the
         * map holds a duplicate. A zero stride cannot be represented: it is
         * used as a divisor in pmap_lpid_to_rank. Keep the explicit map. */
        return false;
    }

    /* Verify every entry against the candidate pattern */
    for (int k = 0; k < size; k++) {
        MPIR_Lpid i_blk = k / blocksize;
        MPIR_Lpid r_blk = k % blocksize;
        if (map[k] != offset + i_blk * stride + r_blk) {
            return false;
        }
    }

    *offset_out = offset;
    *stride_out = stride;
    *blocksize_out = blocksize;
    return true;
}
422+
423+
static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
424+
{
425+
if (rank < 0 || rank >= pmap->size) {
426+
return MPI_UNDEFINED;
427+
}
428+
429+
if (pmap->use_map) {
430+
return pmap->u.map[rank];
431+
} else {
432+
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
433+
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
434+
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
435+
}
436+
}
437+
438+
/* Translate an lpid into its rank in the process map.
 *
 * Returns the rank, or MPI_UNDEFINED when the lpid is not part of the map.
 *
 * Map mode: pmap->u.map is scanned linearly and the first match is returned.
 *
 * Strided mode: members satisfy
 *     lpid = offset + i_blk * stride + r_blk,   i_blk >= 0, 0 <= r_blk < blocksize
 * and the rank is i_blk * blocksize + r_blk. The stride may be negative
 * (blocks laid out in descending order) but must be non-zero since it is
 * used as a divisor.
 */
static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
{
    if (pmap->use_map) {
        /* Use linear search for now.
         * Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
         */
        for (int rank = 0; rank < pmap->size; rank++) {
            if (pmap->u.map[rank] == lpid) {
                return rank;
            }
        }
        return MPI_UNDEFINED;
    } else {
        lpid -= pmap->u.stride.offset;
        /* C division truncates toward zero and the remainder takes the sign
         * of the dividend, so r_blk is negative whenever lpid is below offset
         * and not an exact multiple of stride. */
        MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
        MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;

        if (r_blk < 0) {
            if (pmap->u.stride.stride > 0) {
                /* ascending layout: nothing lives below the first block */
                return MPI_UNDEFINED;
            }
            /* descending layout: move to the next block so that r_blk >= 0 */
            r_blk -= pmap->u.stride.stride;
            i_blk += 1;
        }

        if (i_blk < 0) {
            return MPI_UNDEFINED;
        }

        if (r_blk >= pmap->u.stride.blocksize) {
            /* lpid falls in the gap between two blocks */
            return MPI_UNDEFINED;
        }

        /* Range-check in the lpid type before narrowing to int, so that a
         * huge out-of-range value cannot wrap into [0, size). */
        MPIR_Lpid rank = i_blk * pmap->u.stride.blocksize + r_blk;
        if (rank >= 0 && rank < pmap->size) {
            return (int) rank;
        } else {
            return MPI_UNDEFINED;
        }
    }
}

0 commit comments

Comments
 (0)