@@ -85,7 +85,9 @@ using std::vector;
* messages and exchanged in parallel / asynchronously. Note
* that 32 is a hard upper limit (before MPI length primitives
* overflow), 29 (at double precision) seems to be the upper-limit
- * for UCX exchange, so 28 (to support quad precision) seems safe
+ * for UCX exchange, so 28 (to support quad precision) seems safe.
+ * While 2^28 fits in 'int', we use 'qindex' so that arithmetic
+ * never overflows, and we cast down to 'int' when safe.
*/

qindex MAX_MESSAGE_LENGTH = powerOf2(28);
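A note not in the diff, to make the overflow concern concrete (assuming a 32-bit 'int', with 'long long' standing in for the 64-bit signed 'qindex'):

    int       msgSize = 1 << 28;        // a single message length comfortably fits in int
    long long offset  = 8LL * msgSize;  // 2^31: the start index of the 9th message, which would
                                        // overflow if the product were evaluated in plain int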
@@ -149,7 +151,27 @@ qindex MAX_MESSAGE_LENGTH = powerOf2(28);
*/

- int NULL_TAG = 0;
+ int getMaxNumMessages() {
+ #if COMPILE_MPI
+
+     // the max supported tag value constrains the total number of messages
+     // we can send in a round of communication, since we uniquely tag
+     // each message in a round such that we do not rely upon message-order
+     // guarantees and ergo can safely support UCX adaptive routing (AR).
+     // note MPI returns the MPI_TAG_UB attribute value as a pointer to int
+     int* maxNumMsgs;
+     int isAttribSet;
+
+     MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &maxNumMsgs, &isAttribSet);
+
+     if (!isAttribSet)
+         error_commTagUpperBoundNotSet();
+
+     return *maxNumMsgs;
+
+ #else
+     error_commButEnvNotDistributed();
+     return -1;
+ #endif
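A note for scale, not part of the diff: the MPI standard requires MPI_TAG_UB to be at least 32767, so even a minimal implementation can uniquely tag about 2^15 messages per round; at 2^28 amps per message that is roughly 2^43 amps, or about 128 TiB of double-precision qcomp per exchange, so in practice the tag bound is unlikely to be the limiting factor.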
std::array<qindex,2> dividePow2PayloadIntoMessages(qindex numAmps) {
@@ -159,8 +181,15 @@ std::array<qindex,2> dividePow2PayloadIntoMessages(qindex numAmps) {
    if (numAmps < MAX_MESSAGE_LENGTH)
        return {numAmps, 1};

-     // else, payload divides evenly between max-size messages
-     qindex numMessages = numAmps / MAX_MESSAGE_LENGTH;
+     // else, payload divides evenly between max-size messages (always fits in int)
+     qindex numMessages = numAmps / MAX_MESSAGE_LENGTH;
+
+     // which we must be able to uniquely tag
+     if (numMessages > getMaxNumMessages())
+         error_commNumMessagesExceedTagMax();
+
+     // outputs always fit in 'int' but we force them to be 'qindex' since
+     // caller will multiply them with ints and could easily overflow
    return {MAX_MESSAGE_LENGTH, numMessages};
}
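A worked example, not part of the diff (the binding names are illustrative): a power-of-2 payload of 2^30 amps splits into four max-size messages, which the callers below tag 0 through 3:

    auto [msgSize, numMsgs] = dividePow2PayloadIntoMessages(powerOf2(30));  // msgSize = 2^28, numMsgs = 4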
@@ -172,8 +201,15 @@ std::array<qindex,3> dividePayloadIntoMessages(qindex numAmps) {
        return {numAmps, 1, 0};

    // else, use as many max-size messages as possible, and one smaller msg
-     qindex numMaxSizeMsgs = numAmps / MAX_MESSAGE_LENGTH; // floors
+     qindex numMaxSizeMsgs = numAmps / MAX_MESSAGE_LENGTH; // floors
    qindex remainingMsgSize = numAmps - numMaxSizeMsgs * MAX_MESSAGE_LENGTH;
+
+     // all of which we must be able to uniquely tag
+     if (numMaxSizeMsgs + 1 > getMaxNumMessages())
+         error_commNumMessagesExceedTagMax();
+
+     // outputs always fit in 'int' but we force them to be 'qindex' since
+     // caller will multiply them with ints and could easily overflow
    return {MAX_MESSAGE_LENGTH, numMaxSizeMsgs, remainingMsgSize};
}
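Similarly (not part of the diff): a non-power-of-2 payload of 3*2^27 amps yields one max-size message plus one half-size remainder, i.e. two messages needing unique tags:

    auto [msgSize, numMax, remSize] = dividePayloadIntoMessages(3 * powerOf2(27));  // {2^28, 1, 2^27}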
@@ -193,11 +229,12 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) {
    auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems);
    vector<MPI_Request> requests(2*numMessages, MPI_REQUEST_NULL);

-     // asynchronously exchange the messages (effecting MPI_Isendrecv), exploiting orderedness gaurantee.
-     // note the exploitation of orderedness means we cannot use UCX's adaptive-routing (AR).
+     // asynchronously exchange the messages (effecting MPI_Isendrecv), using unique tags
+     // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing)
    for (qindex m=0; m<numMessages; m++) {
-         MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[2*m]);
-         MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[2*m+1]);
+         int tag = static_cast<int>(m); // guaranteed to fit in int, but m*messageSize needs qindex
+         MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m]);
+         MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m+1]);
    }

    // wait for all exchanges to complete (MPI will automatically free the request memory)
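For a sense of scale (not part of the diff): exchanging 2^32 amps with a pair rank yields 16 messages of 2^28 amps each, hence 32 outstanding requests (one Isend and one Irecv per message) carrying tags 0 through 15, all completed by the subsequent wait.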
@@ -224,9 +261,11 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) {
    // divide the data into multiple messages
    auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems);

-     // asynchronously send the messages; pairRank receives the same ordering
-     for (qindex m=0; m<numMessages; m++)
-         MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &nullReq);
+     // asynchronously send the uniquely-tagged messages
+     for (qindex m=0; m<numMessages; m++) {
+         int tag = static_cast<int>(m); // guaranteed to fit in int, but m*messageSize needs qindex
+         MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &nullReq);
+     }

#else
    error_commButEnvNotDistributed();
@@ -243,9 +282,11 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) {
    // create a request for each asynch receive below
    vector<MPI_Request> requests(numMessages, MPI_REQUEST_NULL);

-     // listen to receive each message asynchronously (as per arxiv.org/abs/2308.07402)
-     for (qindex m=0; m<numMessages; m++)
-         MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]);
+     // listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402)
+     for (qindex m=0; m<numMessages; m++) {
+         int tag = static_cast<int>(m); // guaranteed to fit in int, but m*messageSize needs qindex
+         MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[m]);
+     }

    // receivers wait for all messages to be received (while sender asynch proceeds)
    MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE);
@@ -626,11 +667,14 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps)
    auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numAmps);
    vector<MPI_Request> requests(numMessages, MPI_REQUEST_NULL);

-     // asynchronously copy 'send' in sendRank over to 'recv' in recvRank
-     for (qindex m=0; m<numMessages; m++)
+     // asynchronously copy 'send' in sendRank over to 'recv' in recvRank, using
+     // uniquely-tagged messages such that they may arrive out-of-order, enabling AR
+     for (qindex m=0; m<numMessages; m++) {
+         int tag = static_cast<int>(m);
        (myRank == sendRank)?
-             MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]): // sender
-             MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]); // root
+             MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, MPI_COMM_WORLD, &requests[m]): // sender
+             MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, MPI_COMM_WORLD, &requests[m]); // root
+     }

    // wait for all exchanges to complete (MPI will automatically free the request memory)
    MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE);