Skip to content

Commit c80ccaf

Browse files
committed
ch4/spawn: enter lock in ch4-layer for dynamic apis
The dynamic exchange takes a few rounds of dynamic send/recv, and it is susceptible to being interleaved when multiple threads concurrently try to establish dynamic connections. Using tags alone may not be sufficient. Entering the lock at the ch4-layer for the entire leader exchange stage ensures thread safety. To keep the rules simple, we require that all netmod functions defined in ofi_spawn.c and ucx_spawn.c be called with the lock already held by the caller at the ch4-layer. * add MPID_THREAD_ASSERT_IN_CS in all netmod spawn functions to ensure we don't neglect the lock. * replace and remove the MPIDI_OFI_VCI_CALL macro. * replace and remove MPIDIU_upids_to_lpids by directly calling MPIDI_NM_upids_to_lpids.
1 parent 6e9cce0 commit c80ccaf

File tree

7 files changed

+66
-49
lines changed

7 files changed

+66
-49
lines changed

src/mpid/ch4/netmod/ofi/ofi_impl.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,6 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret);
171171
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \
172172
} while (0)
173173

174-
#define MPIDI_OFI_VCI_CALL(FUNC,vci_,STR) \
175-
do { \
176-
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_)); \
177-
ssize_t _ret = FUNC; \
178-
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \
179-
MPIDI_OFI_ERR(_ret<0, \
180-
mpi_errno, \
181-
MPI_ERR_OTHER, \
182-
"**ofid_"#STR, \
183-
"**ofid_"#STR" %s %s", \
184-
MPIDI_OFI_DEFAULT_NIC_NAME, \
185-
fi_strerror(-_ret)); \
186-
} while (0)
187-
188174
#define MPIDI_OFI_THREAD_CS_ENTER_VCI_OPTIONAL(vci_) \
189175
do { \
190176
if (!MPIDI_VCI_IS_EXPLICIT(vci_) && MPIDI_CH4_MT_MODEL != MPIDI_CH4_MT_LOCKLESS) { \

src/mpid/ch4/netmod/ofi/ofi_spawn.c

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ int MPIDI_OFI_dynamic_send(MPIR_Lpid remote_lpid, int tag, const void *buf, int
1616
int mpi_errno = MPI_SUCCESS;
1717

1818
MPIR_Assert(MPIDI_OFI_ENABLE_TAGGED);
19+
#ifdef MPICH_DEBUG_MUTEX
20+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
21+
#endif
1922

2023
int vci = 0; /* dynamic process only use vci 0 */
2124
int ctx_idx = 0;
@@ -70,6 +73,9 @@ int MPIDI_OFI_dynamic_recv(int tag, void *buf, int size, int timeout)
7073
int mpi_errno = MPI_SUCCESS;
7174

7275
MPIR_Assert(MPIDI_OFI_ENABLE_TAGGED);
76+
#ifdef MPICH_DEBUG_MUTEX
77+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
78+
#endif
7379

7480
int vci = 0; /* dynamic process only use vci 0 */
7581
int ctx_idx = 0;
@@ -238,6 +244,9 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
238244
int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0);
239245

240246
MPIR_CHKLMEM_DECL();
247+
#ifdef MPICH_DEBUG_MUTEX
248+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
249+
#endif
241250

242251
MPIR_CHKLMEM_MALLOC(new_avt_procs, sizeof(int) * size);
243252
MPIR_CHKLMEM_MALLOC(new_upids, sizeof(char *) * size);
@@ -263,9 +272,8 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
263272
for (j = 0; j < MPIDIU_get_av_table(k)->size; j++) {
264273
sz = MPIDI_OFI_global.addrnamelen;
265274
MPIDI_av_entry_t *av = &MPIDIU_get_av(k, j);
266-
MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
267-
MPIDI_OFI_AV_ADDR_ROOT(av), &tbladdr, &sz), 0,
268-
avlookup);
275+
MPIDI_OFI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
276+
MPIDI_OFI_AV_ADDR_ROOT(av), &tbladdr, &sz), avlookup);
269277
if (sz == addrname_len && !memcmp(tbladdr, addrname, addrname_len)) {
270278
remote_lpids[i] = MPIDIU_GPID_CREATE(k, j);
271279
found = 1;
@@ -296,8 +304,8 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
296304
char *addrname = hostname + strlen(hostname) + 1;
297305

298306
fi_addr_t addr;
299-
MPIDI_OFI_VCI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, addrname,
300-
1, &addr, 0ULL, NULL), 0, avmap);
307+
MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, addrname,
308+
1, &addr, 0ULL, NULL), avmap);
301309
MPIR_Assert(addr != FI_ADDR_NOTAVAIL);
302310
MPIDI_OFI_AV_ADDR_ROOT(&MPIDIU_get_av(avtid, i)) = addr;
303311

@@ -325,6 +333,9 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo
325333
int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0);
326334

327335
MPIR_CHKPMEM_DECL();
336+
#ifdef MPICH_DEBUG_MUTEX
337+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
338+
#endif
328339

329340
MPIR_CHKPMEM_MALLOC((*local_upid_size), comm->local_size * sizeof(int), MPL_MEM_ADDRESS);
330341
MPIR_CHKPMEM_MALLOC(temp_buf, comm->local_size * MPIDI_OFI_global.addrnamelen, MPL_MEM_BUFFER);
@@ -352,9 +363,8 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo
352363

353364
size_t sz = MPIDI_OFI_global.addrnamelen;;
354365
MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, i);
355-
MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
356-
MPIDI_OFI_AV_ADDR_ROOT(av),
357-
temp_buf + idx, &sz), 0, avlookup);
366+
MPIDI_OFI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
367+
MPIDI_OFI_AV_ADDR_ROOT(av), temp_buf + idx, &sz), avlookup);
358368
idx += (int) sz;
359369

360370
(*local_upid_size)[i] = upid_len;
@@ -373,6 +383,9 @@ int MPIDI_OFI_insert_upid(MPIR_Lpid lpid, const char *upid, int upid_len)
373383
{
374384
int mpi_errno = MPI_SUCCESS;
375385

386+
#ifdef MPICH_DEBUG_MUTEX
387+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
388+
#endif
376389
const char *hostname = upid;
377390
MPIDI_av_entry_t *av = MPIDIU_lpid_to_av_slow(lpid);
378391

src/mpid/ch4/netmod/ucx/ucx_spawn.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ int MPIDI_UCX_dynamic_send(MPIR_Lpid remote_lpid, int tag, const void *buf, int
2828

2929
uint64_t ucx_tag = MPIDI_UCX_DYNPROC_MASK + tag;
3030
int vci = 0;
31+
#ifdef MPICH_DEBUG_MUTEX
32+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(vci));
33+
#endif
3134

3235
ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(MPIDIU_lpid_to_av_slow(remote_lpid), vci, vci);
3336

@@ -76,6 +79,9 @@ int MPIDI_UCX_dynamic_recv(int tag, void *buf, int size, int timeout)
7679
uint64_t ucx_tag = MPIDI_UCX_DYNPROC_MASK + tag;
7780
uint64_t tag_mask = 0xffffffffffffffff;
7881
int vci = 0;
82+
#ifdef MPICH_DEBUG_MUTEX
83+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(vci));
84+
#endif
7985

8086
bool done = false;
8187
ucp_request_param_t param = {
@@ -206,6 +212,9 @@ int MPIDI_UCX_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo
206212
{
207213
int mpi_errno = MPI_SUCCESS;
208214

215+
#ifdef MPICH_DEBUG_MUTEX
216+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
217+
#endif
209218
MPIR_CHKPMEM_DECL();
210219
MPIR_CHKPMEM_MALLOC((*local_upid_size), comm->local_size * sizeof(int), MPL_MEM_ADDRESS);
211220
MPIR_CHKPMEM_MALLOC((*local_upids), comm->local_size * MPID_MAX_BC_SIZE, MPL_MEM_BUFFER);
@@ -230,6 +239,9 @@ int MPIDI_UCX_insert_upid(MPIR_Lpid lpid, const char *upid, int upid_len)
230239
int mpi_errno = MPI_SUCCESS;
231240
MPIDI_av_entry_t *av = MPIDIU_lpid_to_av_slow(lpid);
232241

242+
#ifdef MPICH_DEBUG_MUTEX
243+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
244+
#endif
233245
bool is_dynamic = (lpid & MPIR_LPID_DYNAMIC_MASK);
234246
bool do_insert = false;
235247
if (is_dynamic) {
@@ -278,6 +290,9 @@ int MPIDI_UCX_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
278290
int vci = 0;
279291
MPIR_CHKLMEM_DECL();
280292

293+
#ifdef MPICH_DEBUG_MUTEX
294+
MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0));
295+
#endif
281296
MPIR_CHKLMEM_MALLOC(new_avt_procs, sizeof(int) * size);
282297
MPIR_CHKLMEM_MALLOC(new_upids, sizeof(char *) * size);
283298

src/mpid/ch4/src/ch4_comm.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,9 @@ int MPID_Intercomm_exchange(MPIR_Comm * local_comm, int local_leader,
496496
/* Stage 1.1 UPID exchange between leaders */
497497
MPIR_CHKLMEM_MALLOC(remote_upid_size, (*remote_size) * sizeof(int));
498498

499+
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0));
499500
mpi_errno = MPIDI_NM_get_local_upids(local_comm, &local_upid_size, &local_upids);
501+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
500502
MPIR_ERR_CHECK(mpi_errno);
501503
mpi_errno = MPIC_Sendrecv(local_upid_size, local_size, MPIR_INT_INTERNAL,
502504
remote_leader, tag,
@@ -519,7 +521,11 @@ int MPID_Intercomm_exchange(MPIR_Comm * local_comm, int local_leader,
519521
MPIR_ERR_CHECK(mpi_errno);
520522

521523
/* Stage 1.2 convert remote UPID to GPID and get GPID for local group */
522-
MPIDIU_upids_to_lpids(*remote_size, remote_upid_size, remote_upids, *remote_lpids);
524+
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0));
525+
mpi_errno = MPIDI_NM_upids_to_lpids(*remote_size, remote_upid_size, remote_upids,
526+
*remote_lpids);
527+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
528+
MPIR_ERR_CHECK(mpi_errno);
523529
} else {
524530
/* Stage 1.1f only exchange GPIDS if no dynamic process involved */
525531
MPI_Aint local_bytes = sizeof(local_lpids[0]) * local_size;
@@ -649,7 +655,11 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i
649655
local_leader, local_comm, MPIR_ERR_NONE);
650656
MPIR_ERR_CHECK(mpi_errno);
651657

652-
MPIDIU_upids_to_lpids(*remote_size, _remote_upid_size, _remote_upids, *remote_lpids);
658+
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0));
659+
mpi_errno = MPIDI_NM_upids_to_lpids(*remote_size, _remote_upid_size, _remote_upids,
660+
*remote_lpids);
661+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
662+
MPIR_ERR_CHECK(mpi_errno);
653663
} else {
654664
MPI_Aint remote_bytes = sizeof(MPIR_Lpid) * (*remote_size);
655665
mpi_errno = MPIR_Bcast_allcomm_auto(*remote_lpids, remote_bytes, MPIR_BYTE_INTERNAL,

src/mpid/ch4/src/ch4_proc.c

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -376,27 +376,6 @@ void MPIDIU_upidhash_free(void)
376376
}
377377
#endif
378378

379-
/* convert upid to gpid by netmod.
380-
* For ofi netmod, it inserts the address and fills an av entry.
381-
*/
382-
int MPIDIU_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids,
383-
MPIR_Lpid * remote_lpids)
384-
{
385-
int mpi_errno = MPI_SUCCESS;
386-
MPIR_FUNC_ENTER;
387-
388-
MPID_THREAD_CS_ENTER(VCI, MPIDIU_THREAD_DYNPROC_MUTEX);
389-
mpi_errno = MPIDI_NM_upids_to_lpids(size, remote_upid_size, remote_upids, remote_lpids);
390-
MPIR_ERR_CHECK(mpi_errno);
391-
392-
fn_exit:
393-
MPID_THREAD_CS_EXIT(VCI, MPIDIU_THREAD_DYNPROC_MUTEX);
394-
MPIR_FUNC_EXIT;
395-
return mpi_errno;
396-
fn_fail:
397-
goto fn_exit;
398-
}
399-
400379
int MPIDIU_alloc_lut(MPIDI_rank_map_lut_t ** lut, int size)
401380
{
402381
int mpi_errno = MPI_SUCCESS;

src/mpid/ch4/src/ch4_proc.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ void MPIDIU_upidhash_add(const void *upid, int upid_len, int avtid, int lpid);
3333
MPIDI_upid_hash *MPIDIU_upidhash_find(const void *upid, int upid_len);
3434
void MPIDIU_upidhash_free(void);
3535
#endif
36-
int MPIDIU_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids,
37-
MPIR_Lpid * remote_lpids);
3836
int MPIDIU_alloc_lut(MPIDI_rank_map_lut_t ** lut, int size);
3937
int MPIDIU_release_lut(MPIDI_rank_map_lut_t * lut);
4038
int MPIDIU_alloc_mlut(MPIDI_rank_map_mlut_t ** mlut, int size);

src/mpid/ch4/src/ch4_spawn.c

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ int MPID_Open_port(MPIR_Info * info_ptr, char *port_name, int len)
235235
mpi_errno = get_port_name_tag(&tag);
236236
MPIR_ERR_CHECK(mpi_errno);
237237

238+
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0));
238239
mpi_errno = MPIDI_NM_get_local_upids(MPIR_Process.comm_self, &addrname_size, &addrname);
240+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
239241
MPIR_ERR_CHECK(mpi_errno);
240242

241243
int err;
@@ -291,16 +293,25 @@ static int peer_intercomm_create(char *remote_addrname, int len, int tag,
291293
int mpi_errno = MPI_SUCCESS;
292294
int context_id, recvcontext_id;
293295
MPIR_Lpid remote_lpid;
296+
bool need_unlock = false;
294297

295298
mpi_errno = MPIR_Get_contextid_sparse(MPIR_Process.comm_self, &recvcontext_id, FALSE);
296299
MPIR_ERR_CHECK(mpi_errno);
297300

301+
/* We enter the LOCK to ensure the dynamic exchange doesn't get interleaved.
302+
* NOTE: most other functions enter the lock at the NM-layer, except the functions defined
303+
* in e.g. ofi_spawn.c and ucx_spawn.c. So only those functions are allowed
304+
* inside the CS.
305+
*/
306+
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0));
307+
need_unlock = true;
308+
298309
struct dynproc_conn_hdr hdr;
299310
if (is_sender) {
300311
/* insert remote address */
301312
int addrname_len = len;
302313
MPIR_Lpid *remote_lpids = &remote_lpid;
303-
mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, remote_addrname, remote_lpids);
314+
mpi_errno = MPIDI_NM_upids_to_lpids(1, &addrname_len, remote_addrname, remote_lpids);
304315
MPIR_ERR_CHECK(mpi_errno);
305316

306317
/* fill hdr with context_id and addrname */
@@ -334,14 +345,16 @@ static int peer_intercomm_create(char *remote_addrname, int len, int tag,
334345
/* insert remote address */
335346
int addrname_len = hdr.addrname_len;
336347
MPIR_Lpid *remote_lpids = &remote_lpid;
337-
mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, hdr.addrname, remote_lpids);
348+
mpi_errno = MPIDI_NM_upids_to_lpids(1, &addrname_len, hdr.addrname, remote_lpids);
338349
MPIR_ERR_CHECK(mpi_errno);
339350

340351
/* send remote context_id */
341352
hdr.context_id = recvcontext_id;
342353
mpi_errno = MPIDI_NM_dynamic_send(remote_lpid, tag, &hdr, sizeof(hdr.context_id), timeout);
343354
MPIR_ERR_CHECK(mpi_errno);
344355
}
356+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
357+
need_unlock = false;
345358

346359
/* create peer intercomm */
347360
mpi_errno = MPIR_peer_intercomm_create(context_id, recvcontext_id,
@@ -354,6 +367,9 @@ static int peer_intercomm_create(char *remote_addrname, int len, int tag,
354367
if (recvcontext_id) {
355368
MPIR_Free_contextid(recvcontext_id);
356369
}
370+
if (need_unlock) {
371+
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0));
372+
}
357373
goto fn_exit;
358374
}
359375

0 commit comments

Comments
 (0)