Skip to content

Commit 3a19b9b

Browse files
authored
Merge pull request #194 from ROSS-org/fix-now-trigger
Fixing GVT hook trigger when called by LPs
2 parents f27cff5 + 241b2ed commit 3a19b9b

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

core/gvt/mpi_allreduce.c

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ tw_gvt_step2(tw_pe *me)
211211
tw_copy_event_sig(&pq_min_sig, tw_pq_minimum_sig_ptr(me->pq));
212212
tw_copy_event_sig(&net_min_sig, tw_net_minimum_sig_ptr());
213213

214-
lvt_sig = me->trans_msg_sig;
214+
tw_copy_event_sig(&lvt_sig, &me->trans_msg_sig);
215215
if(tw_event_sig_compare_ptr(&lvt_sig, &pq_min_sig) > 0)
216216
{
217217
tw_copy_event_sig(&lvt_sig, &pq_min_sig);
@@ -220,6 +220,12 @@ tw_gvt_step2(tw_pe *me)
220220
{
221221
tw_copy_event_sig(&lvt_sig, &net_min_sig);
222222
}
223+
if(g_tw_gvt_hook
224+
&& g_tw_gvt_hook_trigger.status
225+
&& tw_event_sig_compare_ptr(&lvt_sig, &g_tw_gvt_hook_trigger.sig_at) > 0)
226+
{
227+
tw_copy_event_sig(&lvt_sig, &g_tw_gvt_hook_trigger.sig_at);
228+
}
223229

224230
all_reduce_cnt++;
225231
if(MPI_Allreduce(
@@ -517,6 +523,7 @@ void tw_trigger_gvt_hook_every(int num_gvt_calls) {
517523
g_tw_gvt_hook_trigger.status = GVT_HOOK_STATUS_every_n_gvt;
518524
g_tw_gvt_hook_trigger.every_n_gvt.starting_at = g_tw_gvt_done;
519525
g_tw_gvt_hook_trigger.every_n_gvt.nums = num_gvt_calls;
526+
tw_copy_event_sig(&g_tw_gvt_hook_trigger.sig_at, &g_tw_max_sig);
520527
}
521528

522529
void tw_trigger_gvt_hook_when_model_calls(void) {
@@ -532,8 +539,14 @@ void tw_trigger_gvt_hook_now(tw_lp * lp) {
532539
if (g_tw_gvt_hook_trigger.status != GVT_HOOK_STATUS_model_call) {
533540
tw_error(TW_LOC, "`tw_trigger_gvt_hook_now` called but `g_tw_gvt_hook_trigger.status != GVT_HOOK_STATUS_model_call`. Either `tw_trigger_gvt_hook_when_model_calls` was not called or another trigger function has been");
534541
}
542+
if (g_tw_gvt_hook_trigger.sig_at.tie_lineage_length >= MAX_TIE_CHAIN) {
543+
tw_error(TW_LOC, "Maximum zero-offset tie chain reached (%d), increase #define in ross-types.h", MAX_TIE_CHAIN);
544+
}
535545
tw_event_sig * now = &lp->kp->last_sig; // tw_now_sig(lp);
536546
tw_copy_event_sig(&g_tw_gvt_hook_trigger.sig_at, now);
547+
// We store as the trigger time the next valid, larger tiebreaker signature. It is unlikely that this will tie with any other signature
548+
g_tw_gvt_hook_trigger.sig_at.event_tiebreaker[g_tw_gvt_hook_trigger.sig_at.tie_lineage_length] = 0;
549+
g_tw_gvt_hook_trigger.sig_at.tie_lineage_length += 1;
537550

538551
// Forcing GVT to happen now (possibly triggering gvt hook)
539552
lp->pe->gvt_status = TW_GVT_COMPUTE; // same behavior as if calling `tw_gvt_force_update()`

core/ross-gvt-internal.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,19 @@ enum GVT_HOOK_STATUS {
8383
// Holds one timestamp at which to trigger the arbitrary function
8484
struct gvt_hook_trigger {
8585
enum GVT_HOOK_STATUS status;
86-
union {
87-
// GVT_HOOK_TYPE_timestamp and GVT_HOOK_STATUS_model_call
88-
struct {
86+
// GVT_HOOK_TYPE_timestamp and GVT_HOOK_STATUS_model_call
87+
struct {
8988
#ifdef USE_RAND_TIEBREAKER
90-
tw_event_sig sig_at;
89+
tw_event_sig sig_at;
9190
#else
92-
tw_stime at;
91+
tw_stime at;
9392
#endif
94-
};
95-
// GVT_HOOK_TYPE_every_n_gvt
96-
struct {
97-
int starting_at;
98-
int nums;
99-
} every_n_gvt;
10093
};
94+
// GVT_HOOK_STATUS_every_n_gvt
95+
struct {
96+
int starting_at;
97+
int nums;
98+
} every_n_gvt;
10199
};
102100

103101
extern struct gvt_hook_trigger g_tw_gvt_hook_trigger;

core/tw-sched.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ static inline void tw_gvt_hook_step(tw_pe * me) {
601601
}
602602
break;
603603
case GVT_HOOK_STATUS_model_call: {
604-
bool const triggered_here = tw_event_sig_compare_ptr(&me->GVT_sig, &g_tw_gvt_hook_trigger.sig_at) > 0;
604+
bool const triggered_here = tw_event_sig_compare_ptr(&me->GVT_sig, &g_tw_gvt_hook_trigger.sig_at) >= 0;
605605
bool const triggered_somewhere = does_any_pe(triggered_here);
606606
if (triggered_somewhere) {
607607
// LP has triggered GVT hook

0 commit comments

Comments
 (0)