Skip to content

Commit 0195c6a

Browse files
added MPI tags to support AR
By giving each send-recv pair a unique tag, we no longer require message orderedness and can ergo support adaptive routing (AR) --------- Co-authored-by: Oliver Thomson Brown <[email protected]>
1 parent 2cf512a commit 0195c6a

File tree

3 files changed

+77
-19
lines changed

3 files changed

+77
-19
lines changed

quest/src/comm/comm_routines.cpp

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ using std::vector;
8585
* messages and exchanged in parallel / asynchronously. Note
8686
* that 32 is a hard upper limit (before MPI length primitives
8787
* overflow), 29 (at double precision) seems to be the upper-limit
88-
* for UCX exchange, so 28 (to support quad precision) seems safe
88+
* for UCX exchange, so 28 (to support quad precision) seems safe.
89+
* While 2^28 fits in 'int', we use 'qindex' so that arithmetic
90+
* never overflows, and we cast down to 'int' when safe
8991
*/
9092

9193
qindex MAX_MESSAGE_LENGTH = powerOf2(28);
@@ -149,7 +151,27 @@ qindex MAX_MESSAGE_LENGTH = powerOf2(28);
149151
*/
150152

151153

152-
int NULL_TAG = 0;
154+
int getMaxNumMessages() {
155+
#if COMPILE_MPI
156+
157+
// the max supported tag value constrains the total number of messages
158+
// we can send in a round of communication, since we uniquely tag
159+
// each message in a round such that we do not rely upon message-order
160+
// guarantees and ergo can safely support UCX adaptive routing (AR)
161+
int maxNumMsgs, isAttribSet;
162+
163+
MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &maxNumMsgs, &isAttribSet);
164+
165+
if (!isAttribSet)
166+
error_commTagUpperBoundNotSet();
167+
168+
return maxNumMsgs;
169+
170+
#else
171+
error_commButEnvNotDistributed();
172+
return -1;
173+
#endif
174+
}
153175

154176

155177
std::array<qindex,2> dividePow2PayloadIntoMessages(qindex numAmps) {
@@ -159,8 +181,15 @@ std::array<qindex,2> dividePow2PayloadIntoMessages(qindex numAmps) {
159181
if (numAmps < MAX_MESSAGE_LENGTH)
160182
return {numAmps, 1};
161183

162-
// else, payload divides evenly between max-size messages
163-
qindex numMessages = numAmps / MAX_MESSAGE_LENGTH;
184+
// else, payload divides evenly between max-size messages (always fits in int)
185+
qindex numMessages = numAmps / MAX_MESSAGE_LENGTH;
186+
187+
// which we must be able to uniquely tag
188+
if (numMessages > getMaxNumMessages())
189+
error_commNumMessagesExceedTagMax();
190+
191+
// outputs always fit in 'int' but we force them to be 'qindex' since
192+
// caller will multiply them with ints and could easily overflow
164193
return {MAX_MESSAGE_LENGTH, numMessages};
165194
}
166195

@@ -172,8 +201,15 @@ std::array<qindex,3> dividePayloadIntoMessages(qindex numAmps) {
172201
return {numAmps, 1, 0};
173202

174203
// else, use as many max-size messages as possible, and one smaller msg
175-
qindex numMaxSizeMsgs = numAmps / MAX_MESSAGE_LENGTH; // floors
204+
qindex numMaxSizeMsgs = numAmps / MAX_MESSAGE_LENGTH; // floors
176205
qindex remainingMsgSize = numAmps - numMaxSizeMsgs * MAX_MESSAGE_LENGTH;
206+
207+
// all of which we must be able to uniquely tag
208+
if (numMaxSizeMsgs + 1 > getMaxNumMessages())
209+
error_commNumMessagesExceedTagMax();
210+
211+
// outputs always fit in 'int' but we force them to be 'qindex' since
212+
// caller will multiply them with ints and could easily overflow
177213
return {MAX_MESSAGE_LENGTH, numMaxSizeMsgs, remainingMsgSize};
178214
}
179215

@@ -193,11 +229,12 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) {
193229
auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems);
194230
vector<MPI_Request> requests(2*numMessages, MPI_REQUEST_NULL);
195231

196-
// asynchronously exchange the messages (effecting MPI_Isendrecv), exploiting orderedness gaurantee.
197-
// note the exploitation of orderedness means we cannot use UCX's adaptive-routing (AR).
232+
// asynchronously exchange the messages (effecting MPI_Isendrecv), using unique tags
233+
// so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing)
198234
for (qindex m=0; m<numMessages; m++) {
199-
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[2*m]);
200-
MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[2*m+1]);
235+
int tag = static_cast<int>(m); // guaranteed int, but m*messageSize needs qindex
236+
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m]);
237+
MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m+1]);
201238
}
202239

203240
// wait for all exchanges to complete (MPI will automatically free the request memory)
@@ -224,9 +261,11 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) {
224261
// divide the data into multiple messages
225262
auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems);
226263

227-
// asynchronously send the messages; pairRank receives the same ordering
228-
for (qindex m=0; m<numMessages; m++)
229-
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &nullReq);
264+
// asynchronously send the uniquely-tagged messages
265+
for (qindex m=0; m<numMessages; m++) {
266+
int tag = static_cast<int>(m); // guaranteed int, but m*messageSize needs qindex
267+
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &nullReq);
268+
}
230269

231270
#else
232271
error_commButEnvNotDistributed();
@@ -243,9 +282,11 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) {
243282
// create a request for each asynch receive below
244283
vector<MPI_Request> requests(numMessages, MPI_REQUEST_NULL);
245284

246-
// listen to receive each message asynchronously (as per arxiv.org/abs/2308.07402)
247-
for (qindex m=0; m<numMessages; m++)
248-
MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]);
285+
// listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402)
286+
for (qindex m=0; m<numMessages; m++) {
287+
int tag = static_cast<int>(m); // guaranteed int, but m*messageSize needs qindex
288+
MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[m]);
289+
}
249290

250291
// receivers wait for all messages to be received (while sender asynch proceeds)
251292
MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE);
@@ -626,11 +667,14 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps)
626667
auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numAmps);
627668
vector<MPI_Request> requests(numMessages, MPI_REQUEST_NULL);
628669

629-
// asynchronously copy 'send' in sendRank over to 'recv' in recvRank
630-
for (qindex m=0; m<numMessages; m++)
670+
// asynchronously copy 'send' in sendRank over to 'recv' in recvRank, using
671+
// uniquely-tagged messages such that they may arrive out-of-order, enabling AR
672+
for (qindex m=0; m<numMessages; m++) {
673+
int tag = static_cast<int>(m);
631674
(myRank == sendRank)?
632-
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]): // sender
633-
MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, NULL_TAG, MPI_COMM_WORLD, &requests[m]); // root
675+
MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, MPI_COMM_WORLD, &requests[m]): // sender
676+
MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, MPI_COMM_WORLD, &requests[m]); // root
677+
}
634678

635679
// wait for all exchanges to complete (MPI will automatically free the request memory)
636680
MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE);

quest/src/core/errors.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ void error_commGivenInconsistentNumSubArraysANodes() {
145145
raiseInternalError("A distributed function was given a different number of per-node subarray lengths than exist nodes.");
146146
}
147147

148+
void error_commTagUpperBoundNotSet() {
149+
150+
raiseInternalError("The MPI attribute MPI_TAG_UB was not set for communicator MPI_COMM_WORLD, such that the maximum number of messages per communication-round could not be determined.");
151+
}
152+
153+
void error_commNumMessagesExceedTagMax() {
154+
155+
raiseInternalError("A function attempted to communicate via more messages than permitted (since there would be more uniquely-tagged messages than the tag upperbound).");
156+
}
157+
148158
void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps) {
149159

150160
bool valid = (

quest/src/core/errors.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ void error_commWithSameRank();
7070

7171
void error_commGivenInconsistentNumSubArraysANodes();
7272

73+
void error_commTagUpperBoundNotSet();
74+
75+
void error_commNumMessagesExceedTagMax();
76+
7377
void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps);
7478

7579
void assert_commPayloadIsPowerOf2(qindex numAmps);

0 commit comments

Comments
 (0)