Skip to content

Commit 714b731

Browse files
Converse: Strengthen locking in node reductions (#3481)
Node reduction messages can arrive at the root before the root has completed its own CmiNodeReduce* call. Therefore, lock around all accesses of _nodereduce_info, and make the _nodereduce_seqID* counters atomic.
1 parent 0746e1b commit 714b731

File tree

2 files changed

+91
-40
lines changed

2 files changed

+91
-40
lines changed

src/conv-core/convcore.C

Lines changed: 88 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,9 +2414,8 @@ void CmiSyncVectorSendAndFree(int destPE, int n, int *sizes, char **msgs) {
24142414
* merge call will not be deleted by the system, and the CmiHandler function
24152415
* will be in charge of its deletion.
24162416
*
2417-
* CmiReduce/CmiReduceStruct MUST be called once by every processor,
2418-
* CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node, and in
2419-
* particular by the rank zero in each node.
2417+
* CmiReduce/CmiReduceStruct MUST be called once by every processor.
2418+
* CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node.
24202419
****************************************************************************/
24212420

24222421
#define REDUCTION_DEBUG 0
@@ -2454,6 +2453,20 @@ struct CmiNodeReduction {
24542453
CmiReduction * red;
24552454
};
24562455

2456+
static inline CmiReductionID CmiReductionIDFetchAdd(CmiReductionID & id, CmiReductionID addend) {
2457+
const CmiReductionID oldid = id;
2458+
id = oldid + addend;
2459+
return oldid;
2460+
}
2461+
#if CMK_SMP
2462+
static inline CmiReductionID CmiReductionIDFetchAdd(std::atomic<CmiReductionID> & id, CmiReductionID addend) {
2463+
return id.fetch_add(addend);
2464+
}
2465+
using CmiNodeReductionID = std::atomic<CmiReductionID>;
2466+
#else
2467+
using CmiNodeReductionID = CmiReductionID;
2468+
#endif
2469+
24572470
CpvStaticDeclare(int, CmiReductionMessageHandler);
24582471
CpvStaticDeclare(int, CmiReductionDynamicRequestHandler);
24592472

@@ -2466,20 +2479,27 @@ CpvStaticDeclare(CmiReductionID, _reduce_seqID_request);
24662479
CpvStaticDeclare(CmiReductionID, _reduce_seqID_dynamic);
24672480

24682481
CsvStaticDeclare(CmiNodeReduction *, _nodereduce_info);
2469-
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_global);
2470-
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_request);
2471-
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_dynamic);
2482+
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_global);
2483+
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_request);
2484+
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_dynamic);
24722485

24732486
enum : CmiReductionID {
24742487
CmiReductionID_globalOffset = 0, /* Reductions that involve the whole set of processors */
24752488
CmiReductionID_requestOffset = 1, /* Reduction IDs that are requested by all the processors (i.e., during initialization) */
24762489
CmiReductionID_dynamicOffset = 2, /* Reduction IDs that are requested by only one processor (typically at runtime) */
2477-
CmiReductionID_multiplier = 3
2490+
2491+
CmiReductionID_multiplier = 4
24782492
};
24792493

2494+
static_assert(CmiIsPow2(CmiReductionID_multiplier),
2495+
"CmiReductionID_multiplier must be a power of two because seqID counters may overflow and wrap to 0");
2496+
2497+
static inline unsigned int CmiGetRedIndex(CmiReductionID id) {
2498+
return id & ~((~0u) << CmiLogMaxReductions);
2499+
}
2500+
24802501
static CmiReduction* CmiGetReductionCreate(int id, short int numChildren) {
2481-
const int idx = id & ~((~0u) << CmiLogMaxReductions);
2482-
auto & redref = CpvAccess(_reduce_info)[idx];
2502+
auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)];
24832503
CmiReduction *red = redref;
24842504
if (red != NULL && red->seqID != id) {
24852505
/* The table needs to be expanded */
@@ -2506,33 +2526,27 @@ static CmiReduction* CmiGetReductionCreate(int id, short int numChildren) {
25062526
}
25072527

25082528
static void CmiClearReduction(int id) {
2509-
const int idx = id & ~((~0u) << CmiLogMaxReductions);
2510-
auto & redref = CpvAccess(_reduce_info)[idx];
2529+
auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)];
25112530
auto red = redref;
25122531
redref = nullptr;
25132532
free(red);
25142533
}
25152534

2516-
static CmiReduction* CmiGetNextReduction(short int numChildren) {
2517-
int id = CpvAccess(_reduce_seqID_global);
2518-
int newid = id + CmiReductionID_multiplier;
2519-
if (id > 0xFFF0) newid = CmiReductionID_globalOffset;
2520-
CpvAccess(_reduce_seqID_global) = newid;
2521-
return CmiGetReductionCreate(id, numChildren);
2535+
static CmiReductionID CmiGetNextReductionID(void) {
2536+
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_global), CmiReductionID_multiplier);
25222537
}
25232538

25242539
CmiReductionID CmiGetGlobalReduction(void) {
2525-
return CpvAccess(_reduce_seqID_request)+=CmiReductionID_multiplier;
2540+
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_request), CmiReductionID_multiplier);
25262541
}
25272542

25282543
CmiReductionID CmiGetDynamicReduction(void) {
25292544
if (CmiMyPe() != 0) CmiAbort("Cannot call CmiGetDynamicReduction on processors other than zero!\n");
2530-
return CpvAccess(_reduce_seqID_dynamic)+=CmiReductionID_multiplier;
2545+
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_dynamic), CmiReductionID_multiplier);
25312546
}
25322547

25332548
static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) {
2534-
const int idx = id & ~((~0u) << CmiLogMaxReductions);
2535-
auto & redref = CsvAccess(_nodereduce_info)[idx].red;
2549+
auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red;
25362550
CmiReduction *red = redref;
25372551
if (red != NULL && red->seqID != id) {
25382552
/* The table needs to be expanded */
@@ -2559,28 +2573,23 @@ static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) {
25592573
}
25602574

25612575
static void CmiClearNodeReduction(int id) {
2562-
const int idx = id & ~((~0u) << CmiLogMaxReductions);
2563-
auto & redref = CsvAccess(_nodereduce_info)[idx].red;
2576+
auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red;
25642577
auto red = redref;
25652578
redref = nullptr;
25662579
free(red);
25672580
}
25682581

2569-
static CmiReduction* CmiGetNextNodeReduction(short int numChildren) {
2570-
int id = CsvAccess(_nodereduce_seqID_global);
2571-
int newid = id + CmiReductionID_multiplier;
2572-
if (id > 0xFFF0) newid = CmiReductionID_globalOffset;
2573-
CsvAccess(_nodereduce_seqID_global) = newid;
2574-
return CmiGetNodeReductionCreate(id, numChildren);
2582+
static CmiReductionID CmiGetNextNodeReductionID(void) {
2583+
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_global), CmiReductionID_multiplier);
25752584
}
25762585

25772586
CmiReductionID CmiGetGlobalNodeReduction(void) {
2578-
return CsvAccess(_nodereduce_seqID_request)+=CmiReductionID_multiplier;
2587+
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_request), CmiReductionID_multiplier);
25792588
}
25802589

25812590
CmiReductionID CmiGetDynamicNodeReduction(void) {
25822591
if (CmiMyNode() != 0) CmiAbort("Cannot call CmiGetDynamicNodeReduction on nodes other than zero!\n");
2583-
return CsvAccess(_nodereduce_seqID_dynamic)+=CmiReductionID_multiplier;
2592+
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_dynamic), CmiReductionID_multiplier);
25842593
}
25852594

25862595
void CmiReductionHandleDynamicRequest(char *msg) {
@@ -2797,14 +2806,16 @@ static void CmiGlobalNodeReduceStruct(void *data, CmiReducePupFn pupFn,
27972806
}
27982807

27992808
void CmiReduce(void *msg, int size, CmiReduceMergeFn mergeFn) {
2800-
CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe()));
2809+
const CmiReductionID id = CmiGetNextReductionID();
2810+
CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe()));
28012811
CmiGlobalReduce(msg, size, mergeFn, red);
28022812
}
28032813

28042814
void CmiReduceStruct(void *data, CmiReducePupFn pupFn,
28052815
CmiReduceMergeFn mergeFn, CmiHandler dest,
28062816
CmiReduceDeleteFn deleteFn) {
2807-
CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe()));
2817+
const CmiReductionID id = CmiGetNextReductionID();
2818+
CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe()));
28082819
CmiGlobalReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);
28092820
}
28102821

@@ -2884,27 +2895,65 @@ void CmiGroupReduceStruct(CmiGroup grp, void *data, CmiReducePupFn pupFn,
28842895
}
28852896

28862897
void CmiNodeReduce(void *msg, int size, CmiReduceMergeFn mergeFn) {
2887-
CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode()));
2898+
const CmiReductionID id = CmiGetNextNodeReductionID();
2899+
#if CMK_SMP
2900+
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
2901+
CmiLock(nodered.lock);
2902+
#endif
2903+
2904+
CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
28882905
CmiGlobalNodeReduce(msg, size, mergeFn, red);
2906+
2907+
#if CMK_SMP
2908+
CmiUnlock(nodered.lock);
2909+
#endif
28892910
}
28902911

28912912
void CmiNodeReduceStruct(void *data, CmiReducePupFn pupFn,
28922913
CmiReduceMergeFn mergeFn, CmiHandler dest,
28932914
CmiReduceDeleteFn deleteFn) {
2894-
CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode()));
2915+
const CmiReductionID id = CmiGetNextNodeReductionID();
2916+
#if CMK_SMP
2917+
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
2918+
CmiLock(nodered.lock);
2919+
#endif
2920+
2921+
CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
28952922
CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);
2923+
2924+
#if CMK_SMP
2925+
CmiUnlock(nodered.lock);
2926+
#endif
28962927
}
28972928

28982929
void CmiNodeReduceID(void *msg, int size, CmiReduceMergeFn mergeFn, CmiReductionID id) {
2930+
#if CMK_SMP
2931+
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
2932+
CmiLock(nodered.lock);
2933+
#endif
2934+
28992935
CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
29002936
CmiGlobalNodeReduce(msg, size, mergeFn, red);
2937+
2938+
#if CMK_SMP
2939+
CmiUnlock(nodered.lock);
2940+
#endif
29012941
}
29022942

29032943
void CmiNodeReduceStructID(void *data, CmiReducePupFn pupFn,
29042944
CmiReduceMergeFn mergeFn, CmiHandler dest,
29052945
CmiReduceDeleteFn deleteFn, CmiReductionID id) {
2946+
#if CMK_SMP
2947+
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
2948+
CmiLock(nodered.lock);
2949+
#endif
2950+
29062951
CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
29072952
CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);
2953+
2954+
#if CMK_SMP
2955+
CmiUnlock(nodered.lock);
2956+
#endif
29082957
}
29092958

29102959
static void CmiHandleReductionMessage(void *msg) {
@@ -2921,8 +2970,7 @@ static void CmiHandleReductionMessage(void *msg) {
29212970
static void CmiHandleNodeReductionMessage(void *msg) {
29222971
const auto id = CmiGetRedID(msg);
29232972
#if CMK_SMP
2924-
const int idx = id & ~((~0u) << CmiLogMaxReductions);
2925-
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[idx];
2973+
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
29262974
CmiLock(nodered.lock);
29272975
#endif
29282976

@@ -2964,11 +3012,11 @@ void CmiReductionsInit(void) {
29643012

29653013
if (CmiMyRank() == 0)
29663014
{
2967-
CsvInitialize(CmiReductionID, _nodereduce_seqID_global);
3015+
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_global);
29683016
CsvAccess(_nodereduce_seqID_global) = CmiReductionID_globalOffset;
2969-
CsvInitialize(CmiReductionID, _nodereduce_seqID_request);
3017+
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_request);
29703018
CsvAccess(_nodereduce_seqID_request) = CmiReductionID_requestOffset;
2971-
CsvInitialize(CmiReductionID, _nodereduce_seqID_dynamic);
3019+
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_dynamic);
29723020
CsvAccess(_nodereduce_seqID_dynamic) = CmiReductionID_dynamicOffset;
29733021
CsvInitialize(CmiNodeReduction *, _nodereduce_info);
29743022
auto noderedinfo = (CmiNodeReduction *)malloc(CmiMaxReductions * sizeof(CmiNodeReduction));

src/conv-core/converse.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@
7878

7979
#define CMI_MSG_NOKEEP(msg) ((CmiMsgHeaderBasic *)msg)->nokeep
8080

81+
#define CmiIsPow2OrZero(v) (((v) & ((v) - 1)) == 0)
82+
#define CmiIsPow2(v) (CmiIsPow2OrZero(v) && (v))
83+
8184
#define CMIALIGN(x,n) (size_t)((~((size_t)n-1))&((x)+(n-1)))
8285
/*#define ALIGN8(x) (size_t)((~7)&((x)+7)) */
8386
#define ALIGN8(x) CMIALIGN(x,8)

0 commit comments

Comments
 (0)