Skip to content

Commit 2419c23

Browse files
committed
OMPI/OSC/UCX: fix issue in impl of MPI_Win_create_dynamic/MPI_Win_attach/MPI_Win_detach
Signed-off-by: Xin Zhao <[email protected]>
1 parent 5544367 commit 2419c23

File tree

3 files changed

+301
-28
lines changed

3 files changed

+301
-28
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
#include "ompi/communicator/communicator.h"
1717

1818
#define OMPI_OSC_UCX_POST_PEER_MAX 32
19+
#define OMPI_OSC_UCX_ATTACH_MAX 32
20+
#define OMPI_OSC_UCX_RKEY_BUF_MAX 1024
1921

2022
typedef struct ompi_osc_ucx_win_info {
2123
ucp_rkey_h rkey;
2224
uint64_t addr;
25+
bool rkey_init;
2326
} ompi_osc_ucx_win_info_t;
2427

2528
typedef struct ompi_osc_ucx_component {
@@ -59,6 +62,18 @@ typedef struct ompi_osc_ucx_epoch_type {
5962
#define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3)
6063
#define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4)
6164
#define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5)
65+
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX))
66+
67+
typedef struct ompi_osc_dynamic_win_info {
68+
uint64_t base;
69+
size_t size;
70+
char rkey_buffer[OMPI_OSC_UCX_RKEY_BUF_MAX];
71+
} ompi_osc_dynamic_win_info_t;
72+
73+
typedef struct ompi_osc_local_dynamic_win_info {
74+
ucp_mem_h memh;
75+
int refcnt;
76+
} ompi_osc_local_dynamic_win_info_t;
6277

6378
typedef struct ompi_osc_ucx_state {
6479
volatile uint64_t lock;
@@ -67,12 +82,16 @@ typedef struct ompi_osc_ucx_state {
6782
volatile uint64_t complete_count; /* # msgs received from complete processes */
6883
volatile uint64_t post_index;
6984
volatile uint64_t post_state[OMPI_OSC_UCX_POST_PEER_MAX];
85+
volatile uint64_t dynamic_win_count;
86+
volatile ompi_osc_dynamic_win_info_t dynamic_wins[OMPI_OSC_UCX_ATTACH_MAX];
7087
} ompi_osc_ucx_state_t;
7188

7289
typedef struct ompi_osc_ucx_module {
7390
ompi_osc_base_module_t super;
7491
struct ompi_communicator_t *comm;
7592
ucp_mem_h memh; /* remote accessible memory */
93+
int flavor;
94+
size_t size;
7695
ucp_mem_h state_memh;
7796
ompi_osc_ucx_win_info_t *win_info_array;
7897
ompi_osc_ucx_win_info_t *state_info_array;
@@ -82,6 +101,7 @@ typedef struct ompi_osc_ucx_module {
82101
int *disp_units;
83102

84103
ompi_osc_ucx_state_t state; /* remote accessible flags */
104+
ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
85105
ompi_osc_ucx_epoch_type_t epoch_type;
86106
ompi_group_t *start_group;
87107
ompi_group_t *post_group;
@@ -184,6 +204,10 @@ int ompi_osc_ucx_flush_all(struct ompi_win_t *win);
184204
int ompi_osc_ucx_flush_local(int target, struct ompi_win_t *win);
185205
int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);
186206

207+
int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
208+
int min_index, int max_index,
209+
uint64_t base, size_t len, int *insert);
210+
187211
void req_completion(void *request, ucs_status_t status);
188212
void internal_req_init(void *request);
189213

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,68 @@ static inline int end_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int
325325
return OMPI_SUCCESS;
326326
}
327327

328+
static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module,
329+
ucp_ep_h ep, int target) {
330+
ucp_rkey_h state_rkey = (module->state_info_array)[target].rkey;
331+
uint64_t remote_state_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET;
332+
size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
333+
char *temp_buf = malloc(len);
334+
ompi_osc_dynamic_win_info_t *temp_dynamic_wins;
335+
int win_count, contain, insert = -1;
336+
ucs_status_t status;
337+
338+
if ((module->win_info_array[target]).rkey_init == true) {
339+
ucp_rkey_destroy((module->win_info_array[target]).rkey);
340+
(module->win_info_array[target]).rkey_init == false;
341+
}
342+
343+
status = ucp_get_nbi(ep, (void *)temp_buf, len, remote_state_addr, state_rkey);
344+
if (status != UCS_OK && status != UCS_INPROGRESS) {
345+
opal_output_verbose(1, ompi_osc_base_framework.framework_output,
346+
"%s:%d: ucp_get_nbi failed: %d\n",
347+
__FILE__, __LINE__, status);
348+
return OMPI_ERROR;
349+
}
350+
351+
status = ucp_ep_flush(ep);
352+
if (status != UCS_OK) {
353+
opal_output_verbose(1, ompi_osc_base_framework.framework_output,
354+
"%s:%d: ucp_ep_flush failed: %d\n",
355+
__FILE__, __LINE__, status);
356+
return OMPI_ERROR;
357+
}
358+
359+
memcpy(&win_count, temp_buf, sizeof(uint64_t));
360+
assert(win_count > 0 && win_count <= OMPI_OSC_UCX_ATTACH_MAX);
361+
362+
temp_dynamic_wins = (ompi_osc_dynamic_win_info_t *)(temp_buf + sizeof(uint64_t));
363+
contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count,
364+
remote_addr, 1, &insert);
365+
assert(contain >= 0 && contain < win_count);
366+
367+
status = ucp_ep_rkey_unpack(ep, temp_dynamic_wins[contain].rkey_buffer,
368+
&((module->win_info_array[target]).rkey));
369+
if (status != UCS_OK) {
370+
opal_output_verbose(1, ompi_osc_base_framework.framework_output,
371+
"%s:%d: ucp_ep_rkey_unpack failed: %d\n",
372+
__FILE__, __LINE__, status);
373+
return OMPI_ERROR;
374+
}
375+
376+
(module->win_info_array[target]).rkey_init = true;
377+
378+
free(temp_buf);
379+
380+
return status;
381+
}
382+
328383
int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
329384
int target, ptrdiff_t target_disp, int target_count,
330385
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
331386
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
332387
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
333388
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
334-
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
389+
ucp_rkey_h rkey;
335390
bool is_origin_contig = false, is_target_contig = false;
336391
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
337392
ucs_status_t status;
@@ -342,6 +397,15 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
342397
return ret;
343398
}
344399

400+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
401+
status = get_dynamic_win_info(remote_addr, module, ep, target);
402+
if (status != UCS_OK) {
403+
return OMPI_ERROR;
404+
}
405+
}
406+
407+
rkey = (module->win_info_array[target]).rkey;
408+
345409
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
346410
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
347411

@@ -378,7 +442,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
378442
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
379443
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
380444
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
381-
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
445+
ucp_rkey_h rkey;
382446
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
383447
bool is_origin_contig = false, is_target_contig = false;
384448
ucs_status_t status;
@@ -389,6 +453,15 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
389453
return ret;
390454
}
391455

456+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
457+
status = get_dynamic_win_info(remote_addr, module, ep, target);
458+
if (status != UCS_OK) {
459+
return OMPI_ERROR;
460+
}
461+
}
462+
463+
rkey = (module->win_info_array[target]).rkey;
464+
392465
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
393466
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
394467

@@ -557,10 +630,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
557630
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
558631
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
559632
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
560-
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
633+
ucp_rkey_h rkey;
561634
size_t dt_bytes;
562635
ompi_osc_ucx_internal_request_t *req = NULL;
563636
int ret = OMPI_SUCCESS;
637+
ucs_status_t status;
564638

565639
ret = check_sync_state(module, target, false);
566640
if (ret != OMPI_SUCCESS) {
@@ -572,6 +646,15 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
572646
return ret;
573647
}
574648

649+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
650+
status = get_dynamic_win_info(remote_addr, module, ep, target);
651+
if (status != UCS_OK) {
652+
return OMPI_ERROR;
653+
}
654+
}
655+
656+
rkey = (module->win_info_array[target]).rkey;
657+
575658
ompi_datatype_type_size(dt, &dt_bytes);
576659
memcpy(result_addr, origin_addr, dt_bytes);
577660
req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_CSWAP, *(uint64_t *)compare_addr,
@@ -604,17 +687,27 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
604687
op == &ompi_mpi_op_sum.op) {
605688
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
606689
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
607-
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
690+
ucp_rkey_h rkey;
608691
uint64_t value = *(uint64_t *)origin_addr;
609692
ucp_atomic_fetch_op_t opcode;
610693
size_t dt_bytes;
611694
ompi_osc_ucx_internal_request_t *req = NULL;
695+
ucs_status_t status;
612696

613697
ret = start_atomicity(module, ep, target);
614698
if (ret != OMPI_SUCCESS) {
615699
return ret;
616700
}
617701

702+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
703+
status = get_dynamic_win_info(remote_addr, module, ep, target);
704+
if (status != UCS_OK) {
705+
return OMPI_ERROR;
706+
}
707+
}
708+
709+
rkey = (module->win_info_array[target]).rkey;
710+
618711
ompi_datatype_type_size(dt, &dt_bytes);
619712

620713
if (op == &ompi_mpi_op_replace.op) {
@@ -789,7 +882,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
789882
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
790883
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
791884
uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET;
792-
ucp_rkey_h rkey = (module->state_info_array[target]).rkey;
885+
ucp_rkey_h rkey;
793886
ompi_osc_ucx_request_t *ucx_req = NULL;
794887
ompi_osc_ucx_internal_request_t *internal_req = NULL;
795888
ucs_status_t status;
@@ -800,6 +893,15 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
800893
return ret;
801894
}
802895

896+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
897+
status = get_dynamic_win_info(remote_addr, module, ep, target);
898+
if (status != UCS_OK) {
899+
return OMPI_ERROR;
900+
}
901+
}
902+
903+
rkey = (module->win_info_array[target]).rkey;
904+
803905
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
804906
if (NULL == ucx_req) {
805907
return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
@@ -843,7 +945,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
843945
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
844946
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
845947
uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET;
846-
ucp_rkey_h rkey = (module->state_info_array[target]).rkey;
948+
ucp_rkey_h rkey;
847949
ompi_osc_ucx_request_t *ucx_req = NULL;
848950
ompi_osc_ucx_internal_request_t *internal_req = NULL;
849951
ucs_status_t status;
@@ -854,6 +956,15 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
854956
return ret;
855957
}
856958

959+
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
960+
status = get_dynamic_win_info(remote_addr, module, ep, target);
961+
if (status != UCS_OK) {
962+
return OMPI_ERROR;
963+
}
964+
}
965+
966+
rkey = (module->win_info_array[target]).rkey;
967+
857968
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
858969
if (NULL == ucx_req) {
859970
return OMPI_ERR_TEMP_OUT_OF_RESOURCE;

0 commit comments

Comments
 (0)