Skip to content

Commit 5e2c277

Browse files
committed
ch4/ofi: delay allocate addrs for multiple vci/nic
Let MPIDI_OFI_addr_t contain only the field for the root vci address, and allocate space for additional addresses only when multiple vcis and nics are enabled -- potentially at runtime. This avoids wasting memory on multiple vcis unless they are actually needed.
1 parent 477718a commit 5e2c277

File tree

5 files changed

+76
-41
lines changed

5 files changed

+76
-41
lines changed

src/mpid/ch4/netmod/ofi/init_addrxchg.c

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void)
133133

134134
for (int i = 0; i < num_nodes; i++) {
135135
MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL);
136-
MPIDI_OFI_AV(&MPIDIU_get_av(0, node_roots[i])).dest[0][0] = mapped_table[i];
136+
MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, node_roots[i])) = mapped_table[i];
137137
}
138138
MPL_free(mapped_table);
139139
/* Then, allgather all address names using init_comm */
@@ -149,7 +149,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void)
149149
char *addrname = (char *) table + recv_bc_len * rank_map[i];
150150
MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[0].av,
151151
addrname, 1, &addr, 0ULL, NULL), avmap);
152-
MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = addr;
152+
MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = addr;
153153
}
154154
}
155155
mpi_errno = MPIDU_bc_table_destroy();
@@ -163,7 +163,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void)
163163

164164
for (int i = 0; i < size; i++) {
165165
MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL);
166-
MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = mapped_table[i];
166+
MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = mapped_table[i];
167167
}
168168
MPL_free(mapped_table);
169169
mpi_errno = MPIDU_bc_table_destroy();
@@ -173,8 +173,8 @@ int MPIDI_OFI_addr_exchange_root_ctx(void)
173173
/* check */
174174
if (MPIDI_OFI_ENABLE_AV_TABLE) {
175175
for (int r = 0; r < size; r++) {
176-
MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r));
177-
MPIR_Assert(av->dest[0][0] == get_root_av_table_index(r));
176+
MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r);
177+
MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == get_root_av_table_index(r));
178178
}
179179
}
180180

@@ -192,7 +192,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void)
192192
/* Macros to reduce clutter, so we can focus on the ordering logic.
193193
* Note: they are not perfectly wrapped, but tolerable since only used here. */
194194
#define GET_AV_AND_ADDRNAMES(rank) \
195-
MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, rank)); \
195+
MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, rank); \
196196
char *r_names = all_names + rank * max_vcis * num_nics * name_len;
197197

198198
#define DO_AV_INSERT(ctx_idx, nic, vci) \
@@ -244,6 +244,14 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
244244
goto fn_exit;
245245
}
246246

247+
/* allocate additional av addrs */
248+
for (int i = 0; i < size; i++) {
249+
MPIDI_av_entry_t *av = &MPIDIU_get_av(0, i);
250+
MPIDI_OFI_AV(av).all_dest = MPL_malloc(max_vcis * num_nics * sizeof(fi_addr_t),
251+
MPL_MEM_ADDRESS);
252+
MPIR_ERR_CHKANDJUMP(!MPIDI_OFI_AV(av).all_dest, mpi_errno, MPI_ERR_OTHER, "**nomem");
253+
}
254+
247255
/* libfabric uses uniform name_len within a single provider */
248256
int name_len = MPIDI_OFI_global.addrnamelen;
249257
int my_len = max_vcis * num_nics * name_len;
@@ -274,7 +282,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
274282
for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) {
275283
SKIP_ROOT(nic, vci);
276284
DO_AV_INSERT(root_ctx_idx, nic, vci);
277-
av->dest[nic][vci] = addr;
285+
MPIDI_OFI_AV_ADDR(av, 0, 0, vci, nic) = addr;
278286
}
279287
}
280288
}
@@ -304,23 +312,23 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
304312
if (is_node_roots[r]) {
305313
GET_AV_AND_ADDRNAMES(r);
306314
DO_AV_INSERT(ctx_idx, 0, 0);
307-
MPIR_Assert(av->dest[0][0] == addr);
315+
MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr);
308316
}
309317
}
310318
/* non-node-root */
311319
for (int r = 0; r < size; r++) {
312320
if (!is_node_roots[r]) {
313321
GET_AV_AND_ADDRNAMES(r);
314322
DO_AV_INSERT(ctx_idx, 0, 0);
315-
MPIR_Assert(av->dest[0][0] == addr);
323+
MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr);
316324
}
317325
}
318326
} else {
319327
/* !MPIR_CVAR_CH4_ROOTS_ONLY_PMI */
320328
for (int r = 0; r < size; r++) {
321329
GET_AV_AND_ADDRNAMES(r);
322330
DO_AV_INSERT(ctx_idx, 0, 0);
323-
MPIR_Assert(av->dest[0][0] == addr);
331+
MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr);
324332
}
325333
}
326334

@@ -331,7 +339,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
331339
for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) {
332340
SKIP_ROOT(nic, vci);
333341
DO_AV_INSERT(ctx_idx, nic, vci);
334-
MPIR_Assert(av->dest[nic][vci] == addr);
342+
MPIR_Assert(MPIDI_OFI_AV_ADDR(av, 0, 0, vci, nic) == addr);
335343
}
336344
}
337345
}
@@ -344,11 +352,11 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
344352
#if MPIDI_CH4_MAX_VCIS > 1
345353
if (MPIDI_OFI_ENABLE_AV_TABLE) {
346354
for (int r = 0; r < size; r++) {
347-
MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r));
355+
MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r);
348356
for (int nic = 0; nic < num_nics; nic++) {
349357
for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) {
350-
MPIR_Assert(av->dest[nic][vci] == get_av_table_index(r, nic, vci,
351-
all_num_vcis));
358+
MPIR_Assert(MPIDI_OFI_AV_ADDR(av, 0, 0, vci, nic) ==
359+
get_av_table_index(r, nic, vci, all_num_vcis));
352360
}
353361
}
354362
}

src/mpid/ch4/netmod/ofi/ofi_impl.h

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,29 @@ ATTRIBUTE((unused));
3333
#define MPIDI_OFI_COMM(comm) ((comm)->dev.ch4.netmod.ofi)
3434
#define MPIDI_OFI_COMM_TO_INDEX(comm,rank) \
3535
MPIDIU_comm_rank_to_pid(comm, rank, NULL, NULL)
36-
#define MPIDI_OFI_TO_PHYS(avtid, lpid, _nic) \
37-
MPIDI_OFI_AV(&MPIDIU_get_av((avtid), (lpid))).dest[_nic][0]
36+
37+
#define MPIDI_OFI_AV_ROOT_ADDR(av) MPIDI_OFI_AV(av).root_dest
38+
39+
/* NOTE: these macros are a mess to read. They will be cleaned up in a few commits. */
40+
#ifdef MPIDI_OFI_VNI_USE_DOMAIN
41+
#define MPIDI_OFI_AV_ADDR_ROOT(av) \
42+
MPIDI_OFI_AV(av).root_dest
43+
#define MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic) \
44+
(MPIDI_OFI_AV(av).all_dest[(vci)*MPIDI_OFI_global.num_nics+(nic)] + MPIDI_OFI_AV(av).root_offset)
45+
#define MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic) \
46+
MPIDI_OFI_AV(av).all_dest[(vci)*MPIDI_OFI_global.num_nics+(nic)]
47+
#else /* scalable endpoints - all vci share the same addr */
48+
#define MPIDI_OFI_AV_ADDR_ROOT(av) \
49+
MPIDI_OFI_AV(av).root_dest
50+
#define MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic) \
51+
(MPIDI_OFI_AV(av).all_dest[nic] + MPIDI_OFI_AV(av).root_offset)
52+
#define MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic) \
53+
MPIDI_OFI_AV(av).all_dest[nic]
54+
#endif
55+
#define MPIDI_OFI_AV_ADDR(av, local_vci, local_nic, vci, nic) \
56+
((local_vci==0 && local_nic==0) ? \
57+
((vci == 0 && nic == 0) ? MPIDI_OFI_AV_ADDR_ROOT(av) : MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic)) : \
58+
MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic))
3859

3960
#define MPIDI_OFI_WIN(win) ((win)->dev.netmod.ofi)
4061

@@ -444,18 +465,19 @@ MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av,
444465
int local_vci, int local_nic,
445466
int vci, int nic)
446467
{
468+
fi_addr_t dest = MPIDI_OFI_AV_ADDR(av, local_vci, local_nic, vci, nic);
447469
#ifdef MPIDI_OFI_VNI_USE_DOMAIN
448470
if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
449-
return fi_rx_addr(MPIDI_OFI_AV(av).dest[nic][vci], 0, MPIDI_OFI_MAX_ENDPOINTS_BITS);
471+
return fi_rx_addr(dest, 0, MPIDI_OFI_MAX_ENDPOINTS_BITS);
450472
} else {
451-
return MPIDI_OFI_AV(av).dest[nic][vci];
473+
return dest;
452474
}
453475
#else /* MPIDI_OFI_VNI_USE_SEPCTX */
454476
if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
455-
return fi_rx_addr(MPIDI_OFI_AV(av).dest[nic][0], vci, MPIDI_OFI_MAX_ENDPOINTS_BITS);
477+
return fi_rx_addr(dest, vci, MPIDI_OFI_MAX_ENDPOINTS_BITS);
456478
} else {
457479
MPIR_Assert(vci == 0);
458-
return MPIDI_OFI_AV(av).dest[nic][0];
480+
return dest;
459481
}
460482
#endif
461483
}

src/mpid/ch4/netmod/ofi/ofi_init.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,13 @@ int MPIDI_OFI_mpi_finalize_hook(void)
11121112
fi_freeinfo(MPIDI_OFI_global.prov_use[i]);
11131113
}
11141114

1115+
/* free av entries for multiple vcis and nics */
1116+
for (i = 0; i < MPIR_Process.size; i++) {
1117+
MPIDI_av_entry_t *av = &MPIDIU_get_av(0, i);
1118+
MPL_free(MPIDI_OFI_AV(av).all_dest);
1119+
MPIDI_OFI_AV(av).all_dest = NULL;
1120+
}
1121+
11151122
MPIDIU_map_destroy(MPIDI_OFI_global.win_map);
11161123

11171124
if (MPIDI_OFI_ENABLE_AM) {
@@ -1182,7 +1189,7 @@ static int create_sep_tx(struct fid_ep *ep, int idx, struct fid_ep **p_tx,
11821189
struct fid_cq *cq, struct fid_cntr *cntr, int nic);
11831190
static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx, struct fid_cq *cq,
11841191
int nic);
1185-
static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, int nic);
1192+
static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av);
11861193
static int open_local_av(struct fid_domain *p_domain, struct fid_av **p_av);
11871194

11881195
/* This function creates a vci context which includes all of the OFI-level objects needed to
@@ -1416,7 +1423,7 @@ static int create_vci_domain(struct fid_domain **p_domain, struct fid_av **p_av,
14161423
* Otherwise, set MPIDI_OFI_global.got_named_av and
14171424
* copy the map_addr.
14181425
*/
1419-
if (MPIR_CVAR_CH4_OFI_ENABLE_SHARED_AV && try_open_shared_av(domain, p_av, nic)) {
1426+
if (MPIR_CVAR_CH4_OFI_ENABLE_SHARED_AV && nic == 0 && try_open_shared_av(domain, p_av)) {
14201427
MPIDI_OFI_global.got_named_av = 1;
14211428
} else {
14221429
mpi_errno = open_local_av(domain, p_av);
@@ -1521,14 +1528,10 @@ static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx, struc
15211528
goto fn_exit;
15221529
}
15231530

1524-
static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, int nic)
1531+
static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av)
15251532
{
15261533
int ret = 0;
15271534

1528-
/* It's not possible to use shared address vectors with more than one domain in a single
1529-
* process. If we're trying to do that (for example if we are using MPIDI_OFI_VNI_USE_DOMAIN or
1530-
* we have multiple VNIs because of multi-nic), attempt to open up the shared AV in one VNI and
1531-
* then copy the results to the others later. */
15321535
struct fi_av_attr av_attr;
15331536
memset(&av_attr, 0, sizeof(av_attr));
15341537
if (MPIDI_OFI_ENABLE_AV_TABLE) {
@@ -1551,7 +1554,7 @@ static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, i
15511554
/* directly references the mapped fi_addr_t array instead */
15521555
fi_addr_t *mapped_table = (fi_addr_t *) av_attr.map_addr;
15531556
for (int i = 0; i < MPIR_Process.size; i++) {
1554-
MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[nic][0] = mapped_table[i];
1557+
MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = mapped_table[i];
15551558
MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_MAP, VERBOSE,
15561559
(MPL_DBG_FDEST, " grank mapped to: rank=%d, av=%p, dest=%" PRIu64,
15571560
i, (void *) &MPIDIU_get_av(0, i), mapped_table[i]));

src/mpid/ch4/netmod/ofi/ofi_pre.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,14 @@ typedef struct {
310310
/* Maximum number of network interfaces CH4 can support. */
311311
#define MPIDI_OFI_MAX_NICS 8
312312

313+
/* Imagine a dimension of [local_vci][local_nic][rank][vci][nic] -
314+
* all local endpoints will share the same remote address due to the same insertion order
315+
* and use of FI_AV_TABLE except the local root endpoint.
316+
*/
313317
typedef struct {
314-
#ifdef MPIDI_OFI_VNI_USE_DOMAIN
315-
fi_addr_t dest[MPIDI_OFI_MAX_NICS][MPIDI_CH4_MAX_VCIS]; /* [nic][vci] */
316-
#else
317-
fi_addr_t dest[MPIDI_OFI_MAX_NICS][1];
318-
#endif
318+
fi_addr_t root_dest; /* [0][0][r][0][0] */
319+
fi_addr_t root_offset; /* [0][0][r][vci][nic] - [*][*][r][vci][nic] */
320+
fi_addr_t *all_dest; /* [*][*][r][vci][nic] */
319321
} MPIDI_OFI_addr_t;
320322

321323
#endif /* OFI_PRE_H_INCLUDED */

src/mpid/ch4/netmod/ofi/ofi_spawn.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
143143
int n_new_procs = 0;
144144
int n_avts;
145145
char *curr_upid;
146-
int nic = 0;
147-
int ctx_idx = MPIDI_OFI_get_ctx_index(0, nic);
146+
int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0);
148147

149148
MPIR_CHKLMEM_DECL();
150149

@@ -171,8 +170,9 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
171170
}
172171
for (j = 0; j < MPIDIU_get_av_table(k)->size; j++) {
173172
sz = MPIDI_OFI_global.addrnamelen;
173+
MPIDI_av_entry_t *av = &MPIDIU_get_av(k, j);
174174
MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
175-
MPIDI_OFI_TO_PHYS(k, j, nic), &tbladdr, &sz), 0,
175+
MPIDI_OFI_AV_ROOT_ADDR(av), &tbladdr, &sz), 0,
176176
avlookup);
177177
if (sz == addrname_len && !memcmp(tbladdr, addrname, addrname_len)) {
178178
remote_lpids[i] = MPIDIU_GPID_CREATE(k, j);
@@ -207,7 +207,7 @@ int MPIDI_OFI_upids_to_lpids(int size, int *remote_upid_size, char *remote_upids
207207
MPIDI_OFI_VCI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, addrname,
208208
1, &addr, 0ULL, NULL), 0, avmap);
209209
MPIR_Assert(addr != FI_ADDR_NOTAVAIL);
210-
MPIDI_OFI_AV(&MPIDIU_get_av(avtid, i)).dest[nic][0] = addr;
210+
MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(avtid, i)) = addr;
211211

212212
int node_id;
213213
mpi_errno = MPIR_nodeid_lookup(hostname, &node_id);
@@ -230,8 +230,7 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo
230230
int mpi_errno = MPI_SUCCESS;
231231
int i;
232232
char *temp_buf = NULL;
233-
int nic = 0;
234-
int ctx_idx = MPIDI_OFI_get_ctx_index(0, nic);
233+
int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0);
235234

236235
MPIR_CHKPMEM_DECL();
237236

@@ -260,8 +259,9 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo
260259
idx += hostname_len + 1;
261260

262261
size_t sz = MPIDI_OFI_global.addrnamelen;;
263-
MPIDI_OFI_addr_t *av = &MPIDI_OFI_AV(MPIDIU_comm_rank_to_av(comm, i));
264-
MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av, av->dest[nic][0],
262+
MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, i);
263+
MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av,
264+
MPIDI_OFI_AV_ROOT_ADDR(av),
265265
temp_buf + idx, &sz), 0, avlookup);
266266
idx += (int) sz;
267267

0 commit comments

Comments
 (0)