Skip to content

Commit e5cc49a

Browse files
committed
Collect size information from processes if it's not the same everywhere
Hook into the disp_unit handling to exchange size information if needed. Signed-off-by: Joseph Schuchart <[email protected]>
1 parent a15d221 commit e5cc49a

File tree

2 files changed

+59
-39
lines changed

2 files changed

+59
-39
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,32 +116,19 @@ typedef struct ompi_osc_ucx_mem_ranges {
116116
uint64_t tail;
117117
} ompi_osc_ucx_mem_ranges_t;
118118

119-
/**
120-
* Structure to hold information about shared memory regions.
121-
* We store the rank, it's address, and the size of the window region.
122-
* We don't store the disp_unit here, as that is stored elsewhere already.
123-
*/
124-
struct ompi_osc_ucx_shmem_info_s {
125-
int peer; /* rank of the peer this information belongs to */
126-
char *addr; /* address of the shared memory region */
127-
size_t size; /* size of the shared memory region */
128-
};
129-
130-
typedef struct ompi_osc_ucx_shmem_info_s ompi_osc_ucx_shmem_info_t;
131-
132119
typedef struct ompi_osc_ucx_module {
133120
ompi_osc_base_module_t super;
134121
struct ompi_communicator_t *comm;
135122
int flavor;
136-
size_t size;
123+
size_t size;
124+
size_t *sizes; /* used if !same_size*/
137125
uint64_t *addrs;
138126
uint64_t *state_addrs;
139127
uint64_t *comm_world_ranks;
140128
ptrdiff_t disp_unit; /* if disp_unit >= 0, then everyone has the same
141129
* disp unit size; if disp_unit == -1, then we
142130
* need to look at disp_units */
143131
ptrdiff_t *disp_units;
144-
ompi_osc_ucx_shmem_info_t *shmem_info; /* shared memory info */
145132

146133
ompi_osc_ucx_state_t state; /* remote accessible flags */
147134
ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
@@ -157,6 +144,7 @@ typedef struct ompi_osc_ucx_module {
157144
bool lock_all_is_nocheck;
158145
bool no_locks;
159146
bool acc_single_intrinsic;
147+
bool same_size;
160148
opal_common_ucx_ctx_t *ctx;
161149
opal_common_ucx_wpmem_t *mem;
162150
opal_common_ucx_wpmem_t *state_mem;

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int pee
487487
if (UCS_OK != ucp_rkey_ptr(rkey, module->addrs[peer], &addr_p)) {
488488
return OMPI_ERR_NOT_AVAILABLE;
489489
}
490-
*size = module->sizes[peer];
491-
*((void**) baseptr) = (void *)module->shmem_addrs[peer];
492-
*disp_unit = module->disp_units[peer];
490+
*size = module->same_size ? module->size : module->sizes[peer];
491+
*((void**) baseptr) = addr_p;
492+
*disp_unit = (module->disp_unit < 0) ? module->disp_units[peer] : module->disp_unit;
493493

494494
return OMPI_SUCCESS;
495495
}
@@ -554,8 +554,9 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
554554
int flavor, int *model) {
555555
ompi_osc_ucx_module_t *module = NULL;
556556
char *name = NULL;
557-
long values[2];
557+
long values[4];
558558
int ret = OMPI_SUCCESS;
559+
int val_count = 0;
559560
int i, comm_size = ompi_comm_size(comm);
560561
bool env_initialized = false;
561562
void *state_base = NULL;
@@ -679,42 +680,70 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
679680
module->acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic", info);
680681
module->skip_sync_check = false;
681682

682-
/**
683-
* TODO: we need to collect the shared memory information from all processes
684-
* on the same node. This includes the size and base address, which needs
685-
* to be passed to ucp_rkey_ptr().
686-
*/
687-
module->shmem_info = NULL;
688-
689683
/* share everyone's displacement units. Only do an allgather if
690684
strictly necessary, since it requires O(p) state. */
691685
values[0] = disp_unit;
692686
values[1] = -disp_unit;
687+
values[2] = size;
688+
values[3] = -(int64_t)size;
693689

694-
ret = module->comm->c_coll->coll_allreduce(MPI_IN_PLACE, values, 2, MPI_LONG,
690+
ret = module->comm->c_coll->coll_allreduce(MPI_IN_PLACE, values, 4, MPI_LONG,
695691
MPI_MIN, module->comm,
696692
module->comm->c_coll->coll_allreduce_module);
697693
if (OMPI_SUCCESS != ret) {
698694
goto error;
699695
}
700696

701-
if (values[0] == -values[1]) { /* everyone has the same disp_unit, we do not need O(p) space */
697+
bool same_disp_unit = (values[0] == -values[1]);
698+
bool same_size = (values[2] == -values[3]);
699+
700+
if (same_disp_unit) { /* everyone has the same disp_unit, we do not need O(p) space */
702701
module->disp_unit = disp_unit;
703-
} else { /* different disp_unit sizes, allocate O(p) space to store them */
704-
module->disp_unit = -1;
705-
module->disp_units = calloc(comm_size, sizeof(ptrdiff_t));
706-
if (module->disp_units == NULL) {
707-
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
708-
goto error;
709-
}
702+
module->disp_units = NULL;
703+
values[val_count++] = disp_unit;
704+
}
705+
706+
if (same_size) {
707+
module->same_size = true;
708+
module->sizes = NULL;
709+
values[val_count++] = size;
710+
}
711+
712+
if (!same_disp_unit || !same_size) {
710713

711-
ret = module->comm->c_coll->coll_allgather(&disp_unit, sizeof(ptrdiff_t), MPI_BYTE,
712-
module->disp_units, sizeof(ptrdiff_t), MPI_BYTE,
713-
module->comm,
714-
module->comm->c_coll->coll_allgather_module);
714+
ret = module->comm->c_coll->coll_allgather(values, val_count * sizeof(long), MPI_BYTE,
715+
(void *)my_info, sizeof(long) * val_count, MPI_BYTE,
716+
module->comm,
717+
module->comm->c_coll->coll_allgather_module);
715718
if (OMPI_SUCCESS != ret) {
716719
goto error;
717720
}
721+
722+
if (!same_disp_unit) { /* everyone has a different disp_unit */
723+
module->disp_unit = -1;
724+
module->disp_units = calloc(comm_size, sizeof(ptrdiff_t));
725+
if (module->disp_units == NULL) {
726+
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
727+
goto error;
728+
}
729+
for (i = 0; i < comm_size; i++) {
730+
module->disp_units[i] = (ptrdiff_t)values[i*val_count];
731+
}
732+
}
733+
734+
if (!same_size) { /* everyone has the same disp_unit, we do not need O(p) space */
735+
module->same_size = false;
736+
module->sizes = calloc(comm_size, sizeof(size_t));
737+
if (module->sizes == NULL) {
738+
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
739+
goto error;
740+
}
741+
742+
for (i = 0; i < comm_size; i++) {
743+
module->sizes[i] = (size_t)values[i*val_count + val_count-1];
744+
}
745+
}
746+
718747
}
719748

720749
ret = opal_common_ucx_wpctx_create(mca_osc_ucx_component.wpool, comm_size,
@@ -1261,6 +1290,9 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
12611290
if (module->disp_units) {
12621291
free(module->disp_units);
12631292
}
1293+
if (module->sizes) {
1294+
free(module->sizes);
1295+
}
12641296
ompi_comm_free(&module->comm);
12651297

12661298
free(module);

0 commit comments

Comments
 (0)