@@ -56,9 +56,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_do_global_progress(void)
5656
5757/* define MPIDI_PROGRESS to make the code more readable (to avoid nested '#ifdef's) */
5858#ifdef MPIDI_CH4_DIRECT_NETMOD
59- #define MPIDI_PROGRESS (vci ) \
59+ #define MPIDI_PROGRESS (vci , is_global ) \
6060 do { \
61- if (state->flag & MPIDI_PROGRESS_NM && !made_progress) { \
61+ if (state->flag & MPIDI_PROGRESS_NM && (is_global || !made_progress) ) { \
6262 MPIDI_THREAD_CS_ENTER_VCI_OPTIONAL(vci); \
6363 mpi_errno = MPIDI_NM_progress(vci, &made_progress); \
6464 MPIDI_THREAD_CS_EXIT_VCI_OPTIONAL(vci); \
@@ -67,15 +67,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_do_global_progress(void)
6767 } while (0)
6868
6969#else
70- #define MPIDI_PROGRESS (vci ) \
70+ #define MPIDI_PROGRESS (vci , is_global ) \
7171 do { \
72- if (state->flag & MPIDI_PROGRESS_SHM && !made_progress) { \
72+ if (state->flag & MPIDI_PROGRESS_SHM && (is_global || !made_progress) ) { \
7373 MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); \
7474 mpi_errno = MPIDI_SHM_progress(vci, &made_progress); \
7575 MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); \
7676 MPIR_ERR_CHECK(mpi_errno); \
7777 } \
78- if (state->flag & MPIDI_PROGRESS_NM && !made_progress) { \
78+ if (state->flag & MPIDI_PROGRESS_NM && (is_global || !made_progress) ) { \
7979 MPIDI_THREAD_CS_ENTER_VCI_OPTIONAL(vci); \
8080 mpi_errno = MPIDI_NM_progress(vci, &made_progress); \
8181 MPIDI_THREAD_CS_EXIT_VCI_OPTIONAL(vci); \
@@ -125,21 +125,21 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_progress_test(MPID_Progress_state * state)
125125
126126#if MPIDI_CH4_MAX_VCIS == 1
127127 /* fast path for single vci */
128- MPIDI_PROGRESS (0 );
128+ MPIDI_PROGRESS (0 , false );
129129#else
130130 /* multiple vci */
131131 bool is_explicit_vci = (state -> vci_count == 1 && MPIDI_VCI_IS_EXPLICIT (state -> vci [0 ]));
132132 if (!is_explicit_vci && MPIDI_do_global_progress ()) {
133133 for (int vci = 0 ; vci < MPIDI_global .n_vcis ; vci ++ ) {
134- MPIDI_PROGRESS (vci );
134+ MPIDI_PROGRESS (vci , true );
135135 }
136136 } else {
137137 for (int i = 0 ; i < state -> vci_count ; i ++ ) {
138138 int vci = state -> vci [i ];
139139 if (vci >= MPIDI_global .n_total_vcis ) {
140140 continue ;
141141 }
142- MPIDI_PROGRESS (vci );
142+ MPIDI_PROGRESS (vci , false );
143143 }
144144 }
145145#endif
0 commit comments