Skip to content

Commit 1251d73

Browse files
committed
oshmem: scoll: fixes strided alltoall
Signed-off-by: Alex Mikheev <[email protected]> (cherry picked from commit cca67a6)
1 parent b6e825f commit 1251d73

File tree

1 file changed

+137
-69
lines changed

1 file changed

+137
-69
lines changed

oshmem/mca/scoll/basic/scoll_basic_alltoall.c

Lines changed: 137 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@
1919
#include "oshmem/mca/scoll/base/base.h"
2020
#include "scoll_basic.h"
2121

22-
static int _algorithm_simple(struct oshmem_group_t *group,
23-
void *target,
24-
const void *source,
25-
ptrdiff_t dst, ptrdiff_t sst,
26-
size_t nelems,
27-
size_t element_size,
28-
long *pSync);
22+
static int a2a_alg_simple(struct oshmem_group_t *group,
23+
void *target,
24+
const void *source,
25+
size_t nelems,
26+
size_t element_size);
27+
28+
static int a2as_alg_simple(struct oshmem_group_t *group,
29+
void *target,
30+
const void *source,
31+
ptrdiff_t dst, ptrdiff_t sst,
32+
size_t nelems,
33+
size_t element_size);
34+
2935

3036
int mca_scoll_basic_alltoall(struct oshmem_group_t *group,
3137
void *target,
@@ -36,88 +42,150 @@ int mca_scoll_basic_alltoall(struct oshmem_group_t *group,
3642
long *pSync,
3743
int alg)
3844
{
39-
int rc = OSHMEM_SUCCESS;
45+
int rc;
46+
int i;
4047

4148
/* Arguments validation */
4249
if (!group) {
4350
SCOLL_ERROR("Active set (group) of PE is not defined");
44-
rc = OSHMEM_ERR_BAD_PARAM;
51+
return OSHMEM_ERR_BAD_PARAM;
4552
}
4653

4754
/* Check if this PE is part of the group */
48-
if ((rc == OSHMEM_SUCCESS) && oshmem_proc_group_is_member(group)) {
49-
int i = 0;
50-
51-
if (pSync) {
52-
rc = _algorithm_simple(group,
53-
target,
54-
source,
55-
dst,
56-
sst,
57-
nelems,
58-
element_size,
59-
pSync);
60-
} else {
61-
SCOLL_ERROR("Incorrect argument pSync");
62-
rc = OSHMEM_ERR_BAD_PARAM;
63-
}
64-
65-
/* Restore initial values */
66-
SCOLL_VERBOSE(12,
67-
"PE#%d Restore special synchronization array",
68-
group->my_pe);
69-
for (i = 0; pSync && (i < _SHMEM_ALLTOALL_SYNC_SIZE); i++) {
70-
pSync[i] = _SHMEM_SYNC_VALUE;
71-
}
55+
if (!oshmem_proc_group_is_member(group)) {
56+
return OSHMEM_SUCCESS;
7257
}
7358

74-
return rc;
75-
}
59+
if (!pSync) {
60+
SCOLL_ERROR("Incorrect argument pSync");
61+
return OSHMEM_ERR_BAD_PARAM;
62+
}
7663

77-
static int _algorithm_simple(struct oshmem_group_t *group,
78-
void *target,
79-
const void *source,
80-
ptrdiff_t tst, ptrdiff_t sst,
81-
size_t nelems,
82-
size_t element_size,
83-
long *pSync)
84-
{
85-
int rc = OSHMEM_SUCCESS;
86-
int pe_cur;
87-
int i;
88-
int j;
89-
int k;
64+
if ((sst == 1) && (dst == 1)) {
65+
rc = a2a_alg_simple(group, target, source, nelems, element_size);
66+
} else {
67+
rc = a2as_alg_simple(group, target, source, dst, sst, nelems,
68+
element_size);
69+
}
9070

91-
SCOLL_VERBOSE(14,
92-
"[#%d] send data to all PE in the group",
93-
group->my_pe);
94-
j = oshmem_proc_group_find_id(group, group->my_pe);
95-
for (i = 0; i < group->proc_count; i++) {
96-
/* index permutation for better distribution of traffic */
97-
k = (((j)+(i))%(group->proc_count));
98-
pe_cur = oshmem_proc_pe(group->proc_array[k]);
99-
rc = MCA_SPML_CALL(put(
100-
(void *)((char *)target + j * tst * nelems * element_size),
101-
nelems * element_size,
102-
(void *)((char *)source + i * sst * nelems * element_size),
103-
pe_cur));
104-
if (OSHMEM_SUCCESS != rc) {
105-
break;
106-
}
71+
if (rc != OSHMEM_SUCCESS) {
72+
return rc;
10773
}
74+
10875
/* fence (which currently acts as quiet) is needed
10976
* because scoll level barrier does not guarantee put completion
11077
*/
11178
MCA_SPML_CALL(fence());
11279

11380
/* Wait for operation completion */
114-
if (rc == OSHMEM_SUCCESS) {
115-
SCOLL_VERBOSE(14, "[#%d] Wait for operation completion", group->my_pe);
116-
rc = BARRIER_FUNC(group,
117-
(pSync + 1),
118-
SCOLL_DEFAULT_ALG);
81+
SCOLL_VERBOSE(14, "[#%d] Wait for operation completion", group->my_pe);
82+
rc = BARRIER_FUNC(group, pSync + 1, SCOLL_DEFAULT_ALG);
83+
84+
/* Restore initial values */
85+
SCOLL_VERBOSE(12, "PE#%d Restore special synchronization array",
86+
group->my_pe);
87+
88+
for (i = 0; pSync && (i < _SHMEM_ALLTOALL_SYNC_SIZE); i++) {
89+
pSync[i] = _SHMEM_SYNC_VALUE;
11990
}
12091

12192
return rc;
12293
}
12394

95+
96+
static inline void *
97+
get_stride_elem(const void *base, ptrdiff_t sst, size_t nelems, size_t elem_size,
98+
int block_idx, int elem_idx)
99+
{
100+
/*
101+
* j th block starts at: nelems * element_size * sst * j
102+
* offset of the l th element in the block is: element_size * sst * l
103+
*/
104+
return (char *)base + elem_size * sst * (nelems * block_idx + elem_idx);
105+
}
106+
107+
static inline int
108+
get_dst_pe(struct oshmem_group_t *group, int src_blk_idx, int dst_blk_idx)
109+
{
110+
int dst_grp_pe;
111+
112+
/* index permutation for better distribution of traffic */
113+
dst_grp_pe = (dst_blk_idx + src_blk_idx) % group->proc_count;
114+
115+
/* convert to the global pe */
116+
return oshmem_proc_pe(group->proc_array[dst_grp_pe]);
117+
}
118+
119+
static int a2as_alg_simple(struct oshmem_group_t *group,
120+
void *target,
121+
const void *source,
122+
ptrdiff_t tst, ptrdiff_t sst,
123+
size_t nelems,
124+
size_t element_size)
125+
{
126+
int rc;
127+
int dst_pe;
128+
int src_blk_idx;
129+
int dst_blk_idx;
130+
size_t elem_idx;
131+
132+
SCOLL_VERBOSE(14,
133+
"[#%d] send data to all PE in the group",
134+
group->my_pe);
135+
136+
dst_blk_idx = oshmem_proc_group_find_id(group, group->my_pe);
137+
138+
for (src_blk_idx = 0; src_blk_idx < group->proc_count; src_blk_idx++) {
139+
140+
dst_pe = get_dst_pe(group, src_blk_idx, dst_blk_idx);
141+
for (elem_idx = 0; elem_idx < nelems; elem_idx++) {
142+
rc = MCA_SPML_CALL(put(
143+
get_stride_elem(target, tst, nelems, element_size,
144+
dst_blk_idx, elem_idx),
145+
element_size,
146+
get_stride_elem(source, sst, nelems, element_size,
147+
src_blk_idx, elem_idx),
148+
dst_pe));
149+
if (OSHMEM_SUCCESS != rc) {
150+
return rc;
151+
}
152+
}
153+
}
154+
return OSHMEM_SUCCESS;
155+
}
156+
157+
static int a2a_alg_simple(struct oshmem_group_t *group,
158+
void *target,
159+
const void *source,
160+
size_t nelems,
161+
size_t element_size)
162+
{
163+
int rc;
164+
int dst_pe;
165+
int src_blk_idx;
166+
int dst_blk_idx;
167+
void *dst_blk;
168+
169+
SCOLL_VERBOSE(14,
170+
"[#%d] send data to all PE in the group",
171+
group->my_pe);
172+
173+
dst_blk_idx = oshmem_proc_group_find_id(group, group->my_pe);
174+
175+
/* block start at stride 1 first elem */
176+
dst_blk = get_stride_elem(target, 1, nelems, element_size, dst_blk_idx, 0);
177+
178+
for (src_blk_idx = 0; src_blk_idx < group->proc_count; src_blk_idx++) {
179+
180+
dst_pe = get_dst_pe(group, src_blk_idx, dst_blk_idx);
181+
rc = MCA_SPML_CALL(put(dst_blk,
182+
nelems * element_size,
183+
get_stride_elem(source, 1, nelems,
184+
element_size, src_blk_idx, 0),
185+
dst_pe));
186+
if (OSHMEM_SUCCESS != rc) {
187+
return rc;
188+
}
189+
}
190+
return OSHMEM_SUCCESS;
191+
}

0 commit comments

Comments
 (0)