Skip to content

Commit 14d4ada

Browse files
committed
ch4/shm: refactor iqueue shmem allocation
Consolidate the shmem allocations in iqueue into two slabs: a root slab that is set up at world_init, and an all_slab for the per-vci transports that is set up when the vcis are initialized. The goal is to eventually allow more flexible shm creation, potentially allowing initialization within a non-world communicator.
1 parent 4dea369 commit 14d4ada

File tree

4 files changed

+83
-30
lines changed

4 files changed

+83
-30
lines changed

src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
MPIDI_POSIX_eager_iqueue_global_t MPIDI_POSIX_eager_iqueue_global;
3838

39-
static int init_transport(int vci_src, int vci_dst)
39+
static int init_transport(void *slab, int vci_src, int vci_dst)
4040
{
4141
int mpi_errno = MPI_SUCCESS;
4242

@@ -51,28 +51,24 @@ static int init_transport(int vci_src, int vci_dst)
5151
MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPSC,
5252
MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPMC
5353
};
54-
mpi_errno = MPIDU_genq_shmem_pool_create(transport->size_of_cell, transport->num_cells,
54+
mpi_errno = MPIDU_genq_shmem_pool_create(slab, MPIDI_POSIX_eager_iqueue_global.slab_size,
55+
transport->size_of_cell, transport->num_cells,
5556
MPIR_Process.local_size,
5657
MPIR_Process.local_rank,
5758
2, queue_types, &transport->cell_pool);
5859
MPIR_ERR_CHECK(mpi_errno);
5960
} else {
6061
int queue_type = MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPSC;
61-
mpi_errno = MPIDU_genq_shmem_pool_create(transport->size_of_cell, transport->num_cells,
62+
mpi_errno = MPIDU_genq_shmem_pool_create(slab, MPIDI_POSIX_eager_iqueue_global.slab_size,
63+
transport->size_of_cell, transport->num_cells,
6264
MPIR_Process.local_size,
6365
MPIR_Process.local_rank,
6466
1, &queue_type, &transport->cell_pool);
6567
MPIR_ERR_CHECK(mpi_errno);
6668
}
6769

68-
size_t size_of_terminals;
69-
/* Create one terminal for each process with which we will be able to communicate. */
70-
size_of_terminals = (size_t) MPIR_Process.local_size * sizeof(MPIDU_genq_shmem_queue_u);
71-
72-
/* Create the shared memory regions that will be used for the iqueue cells and terminals. */
73-
mpi_errno = MPIDU_Init_shm_alloc(size_of_terminals, (void *) &transport->terminals);
74-
MPIR_ERR_CHECK(mpi_errno);
75-
70+
transport->terminals = (void *) ((char *) slab +
71+
MPIDI_POSIX_eager_iqueue_global.terminal_offset);
7672
transport->my_terminal = &transport->terminals[MPIR_Process.local_rank];
7773

7874
mpi_errno = MPIDU_genq_shmem_queue_init(transport->my_terminal,
@@ -98,7 +94,27 @@ int MPIDI_POSIX_iqueue_init(int rank, int size)
9894
/* Init vci 0. Communication on vci 0 is enabled afterwards. */
9995
MPIDI_POSIX_eager_iqueue_global.max_vcis = 1;
10096

101-
mpi_errno = init_transport(0, 0);
97+
/* calculate needed shmem size per (vci_src, vci_dst) */
98+
int num_free_queue = MPIR_CVAR_CH4_SHM_POSIX_TOPO_ENABLE ? 2 : 1;
99+
int cell_size = MPIR_CVAR_CH4_SHM_POSIX_IQUEUE_CELL_SIZE;
100+
int num_cells = MPIR_CVAR_CH4_SHM_POSIX_IQUEUE_NUM_CELLS;
101+
int nprocs = MPIR_Process.local_size;
102+
103+
int pool_size = MPIDU_genq_shmem_pool_size(cell_size, num_cells, nprocs, num_free_queue);
104+
int terminal_size = num_proc * sizeof(MPIDU_genq_shmem_queue_u);
105+
106+
int slab_size = pool_size + terminal_size;
107+
108+
/* Create the shared memory regions that will be used for the iqueue cells and terminals. */
109+
void *slab;
110+
mpi_errno = MPIDU_Init_shm_alloc(slab_size, (void *) &slab);
111+
MPIR_ERR_CHECK(mpi_errno);
112+
113+
MPIDI_POSIX_eager_iqueue_global.slab_size = slab_size;
114+
MPIDI_POSIX_eager_iqueue_global.terminal_offset = pool_size;
115+
MPIDI_POSIX_eager_iqueue_global.root_slab = slab;
116+
117+
mpi_errno = init_transport(slab, 0, 0);
102118
MPIR_ERR_CHECK(mpi_errno);
103119

104120
mpi_errno = MPIDU_Init_shm_barrier();
@@ -127,15 +143,24 @@ int MPIDI_POSIX_iqueue_post_init(void)
127143
}
128144

129145
MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis;
146+
int slab_size = MPIDI_POSIX_eager_iqueue_global.slab_size * max_vcis * max_vcis;
147+
/* Create the shared memory regions for all vcis */
148+
/* TODO: do shm alloc in a comm */
149+
void *slab;
150+
mpi_errno = MPIDU_Init_shm_alloc(slab_size, (void *) &slab);
151+
MPIR_ERR_CHECK(mpi_errno);
152+
153+
MPIDI_POSIX_eager_iqueue_global.all_slab = slab;
130154

131155
for (int vci_src = 0; vci_src < max_vcis; vci_src++) {
132156
for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) {
133157
if (vci_src == 0 && vci_dst == 0) {
134158
continue;
135159
}
136-
mpi_errno = init_transport(vci_src, vci_dst);
160+
void *p = (char *) slab + (vci_src * max_vcis + vci_dst) *
161+
MPIDI_POSIX_eager_iqueue_global.slab_size;
162+
mpi_errno = init_transport(p, vci_src, vci_dst);
137163
MPIR_ERR_CHECK(mpi_errno);
138-
139164
}
140165
}
141166

@@ -156,18 +181,34 @@ int MPIDI_POSIX_iqueue_finalize(void)
156181

157182
MPIR_FUNC_ENTER;
158183

184+
if (MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab) {
185+
MPIDI_POSIX_eager_iqueue_transport_t *transport;
186+
transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst);
187+
188+
mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool);
189+
MPIR_ERR_CHECK(mpi_errno);
190+
191+
mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab);
192+
MPIR_ERR_CHECK(mpi_errno);
193+
MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab = NULL;
194+
}
195+
196+
if (!MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab) {
197+
goto fn_exit;
198+
}
159199
int max_vcis = MPIDI_POSIX_eager_iqueue_global.max_vcis;
160200
for (int vci_src = 0; vci_src < max_vcis; vci_src++) {
161201
for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) {
162202
MPIDI_POSIX_eager_iqueue_transport_t *transport;
163203
transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst);
164204

165-
mpi_errno = MPIDU_Init_shm_free(transport->terminals);
166-
MPIR_ERR_CHECK(mpi_errno);
167205
mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool);
168206
MPIR_ERR_CHECK(mpi_errno);
169207
}
170208
}
209+
mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab);
210+
MPIR_ERR_CHECK(mpi_errno);
211+
MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab = NULL;
171212

172213
fn_exit:
173214
MPIR_FUNC_EXIT;

src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ typedef struct MPIDI_POSIX_eager_iqueue_transport {
3535

3636
typedef struct MPIDI_POSIX_eager_iqueue_global {
3737
int max_vcis;
38+
/* sizes for shmem slabs */
39+
int slab_size;
40+
int terminal_offset;
41+
/* shmem slabs */
42+
void *root_slab;
43+
void *all_slab;
3844
/* 2d array indexed with [src_vci][dst_vci] */
3945
MPIDI_POSIX_eager_iqueue_transport_t transports[MPIDI_CH4_MAX_VCIS][MPIDI_CH4_MAX_VCIS];
4046
} MPIDI_POSIX_eager_iqueue_global_t;

src/mpid/common/genq/mpidu_genq_shmem_pool.c

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,23 @@ static int cell_block_alloc(MPIDU_genqi_shmem_pool_s * pool, int rank)
9595
goto fn_exit;
9696
}
9797

98-
int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_queue,
99-
uintptr_t num_proc, int rank, uintptr_t num_free_queue,
98+
int MPIDU_genq_shmem_pool_size(int cell_size, int cells_per_free_queue,
99+
int num_proc, int num_free_queue)
100+
{
101+
int aligned_cell_size = RESIZE_TO_MAX_ALIGN(cell_size);
102+
int cell_alloc_size = sizeof(MPIDU_genqi_shmem_cell_header_s) + aligned_cell_size;
103+
int total_cells_size = num_proc * num_free_queue * cells_per_free_queue * cell_alloc_size;
104+
int free_queue_size = num_proc * num_free_queue * sizeof(MPIDU_genq_shmem_queue_u);
105+
return total_cells_size + free_queue_size;
106+
}
107+
108+
int MPIDU_genq_shmem_pool_create(void *slab, int slab_size,
109+
int cell_size, int cells_per_free_queue,
110+
int num_proc, int rank, int num_free_queue,
100111
int *queue_types, MPIDU_genq_shmem_pool_t * pool)
101112
{
102113
int rc = MPI_SUCCESS;
103114
MPIDU_genqi_shmem_pool_s *pool_obj;
104-
uintptr_t slab_size = 0;
105115
uintptr_t aligned_cell_size = 0;
106116

107117
MPIR_FUNC_ENTER;
@@ -117,15 +127,13 @@ int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_q
117127
pool_obj->num_free_queue = num_free_queue;
118128
pool_obj->rank = rank;
119129
pool_obj->gpu_registered = false;
130+
pool_obj->slab = slab;
120131

121132
/* the global_block_index is at the end of the slab to avoid extra need of alignment */
122133
int total_cells_size = num_proc * num_free_queue * cells_per_free_queue
123134
* pool_obj->cell_alloc_size;
124135
int free_queue_size = num_proc * num_free_queue * sizeof(MPIDU_genq_shmem_queue_u);
125-
slab_size = total_cells_size + free_queue_size;
126-
127-
rc = MPIDU_Init_shm_alloc(slab_size, &pool_obj->slab);
128-
MPIR_ERR_CHECK(rc);
136+
MPIR_Assertp(slab_size >= total_cells_size + free_queue_size);
129137

130138
pool_obj->cell_header_base = (MPIDU_genqi_shmem_cell_header_s *) pool_obj->slab;
131139
pool_obj->free_queues =
@@ -140,16 +148,12 @@ int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_q
140148
rc = cell_block_alloc(pool_obj, rank);
141149
MPIR_ERR_CHECK(rc);
142150

143-
rc = MPIDU_Init_shm_barrier();
144-
MPIR_ERR_CHECK(rc);
145-
146151
*pool = (MPIDU_genq_shmem_pool_t) pool_obj;
147152

148153
fn_exit:
149154
MPIR_FUNC_EXIT;
150155
return rc;
151156
fn_fail:
152-
MPIDU_Init_shm_free(pool_obj->slab);
153157
MPL_free(pool_obj);
154158
goto fn_exit;
155159
}
@@ -166,7 +170,6 @@ int MPIDU_genq_shmem_pool_destroy(MPIDU_genq_shmem_pool_t pool)
166170
if (pool_obj->gpu_registered) {
167171
MPIR_gpu_unregister_host(pool_obj->slab);
168172
}
169-
MPIDU_Init_shm_free(pool_obj->slab);
170173

171174
/* free self */
172175
MPL_free(pool_obj);

src/mpid/common/genq/mpidu_genq_shmem_pool.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
#include <stdint.h>
1515
#include <string.h>
1616

17-
int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_queue,
18-
uintptr_t num_proc, int rank, uintptr_t num_free_queue,
17+
int MPIDU_genq_shmem_pool_size(int cell_size, int cells_per_free_queue,
18+
int num_proc, int num_free_queue);
19+
int MPIDU_genq_shmem_pool_create(void *slab, int slab_size,
20+
int cell_size, int cells_per_free_queue,
21+
int num_proc, int rank, int num_free_queue,
1922
int *queue_types, MPIDU_genq_shmem_pool_t * pool);
2023
int MPIDU_genq_shmem_pool_destroy(MPIDU_genq_shmem_pool_t pool);
2124
int MPIDU_genqi_shmem_pool_register(MPIDU_genqi_shmem_pool_s * pool_obj);

0 commit comments

Comments
 (0)