@@ -63,7 +63,7 @@ def _find_fist_opt_user(self, main_program):
     def _get_comm_group(self, ranks=[]):
         ranks = sorted(ranks)
         if tuple(ranks) in self.comm_group:
-            return self.comm_group[tuple(ranks)].id
+            return self.comm_group[tuple(ranks)]
         # The communication group of this `all_reduce` op satisfies len(ranks)==2.
         # When `force_new_group=False` is set, the `send&recv` group will be returned,
         # At this point, `all_reduce` and `send&recv` share the same group, and
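
Note on the first hunk: returning the cached group object (rather than its .id) lets callers pass the group straight into later collective calls. The snippet below is a minimal, self-contained sketch of that caching pattern; FakeGroup and CommGroupCache are hypothetical stand-ins for illustration, not the classes this pass actually uses.

class FakeGroup:
    # Stand-in for the framework's process-group type.
    def __init__(self, gid, ranks):
        self.id = gid
        self.ranks = ranks

class CommGroupCache:
    def __init__(self):
        self.comm_group = {}
        self._next_id = 0

    def get_comm_group(self, ranks=()):
        key = tuple(sorted(ranks))
        if key in self.comm_group:
            return self.comm_group[key]  # hand back the group object, not its id
        group = FakeGroup(self._next_id, key)
        self._next_id += 1
        self.comm_group[key] = group
        return group

cache = CommGroupCache()
g1 = cache.get_comm_group([3, 1])
g2 = cache.get_comm_group([1, 3])
assert g1 is g2 and g1.ranks == (1, 3)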
@@ -205,6 +205,14 @@ def sync_shared_parameters(self, main_program, startup_program):
             logger.info("No parameter need to share, skip pass.")
             return []
 
+        # Must initialize the redundant communication group for the allreduce op here.
+        # Otherwise, it will hang during gradient synchronization.
+        for idx in range(len(self.src_ranks)):
+            rank_1 = self.src_ranks[idx]
+            rank_2 = self.dst_ranks[idx]
+            new_process_group(sorted([rank_1, rank_2]))
+            self._get_comm_group([rank_1, rank_2])
+
         return new_shared_params
 
     def sync_shared_parameter_gradient(
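
Note on the second hunk: creating a communication group is itself a collective step, so every rank must pre-create the redundant pair groups here, even ranks that are not members; otherwise the member ranks can block later, as the added comment warns. The sketch below illustrates the same idea with the public paddle.distributed.new_group API rather than the pass's internal new_process_group helper; the four-rank layout and the src/dst pairing are assumptions for illustration, and the script is meant to be launched with paddle.distributed.launch across four processes.

import paddle
import paddle.distributed as dist

def build_pair_groups(src_ranks, dst_ranks):
    # Every rank runs the same new_group calls in the same order.
    # Ranks outside a given pair still take part in that group's creation,
    # which is what avoids the hang described in the diff comment.
    groups = {}
    for src, dst in zip(src_ranks, dst_ranks):
        ranks = sorted([src, dst])
        groups[tuple(ranks)] = dist.new_group(ranks)
    return groups

if __name__ == "__main__":
    dist.init_parallel_env()
    pair_groups = build_pair_groups(src_ranks=[0, 1], dst_ranks=[2, 3])
    cur_rank = dist.get_rank()
    for ranks, group in pair_groups.items():
        if cur_rank in ranks:
            t = paddle.to_tensor([float(cur_rank)])
            # Only member ranks issue the all_reduce on the pair group.
            dist.all_reduce(t, group=group)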
@@ -228,6 +236,9 @@ def sync_shared_parameter_gradient(
 
         cur_rank = paddle.distributed.get_rank()
 
+        if cur_rank not in self.src_ranks and cur_rank not in self.dst_ranks:
+            return params_grads
+
         pre_name = ""
         if cur_rank in self.dst_ranks:
             pre_name = "shared_"
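
Note on the third hunk: a rank that appears in neither src_ranks nor dst_ranks owns no shared parameter, so the added guard returns params_grads unchanged before any group lookup happens. Below is a minimal sketch of that guard with hypothetical names (the stub function and the rank lists are illustrative, not the pass's actual code).

def sync_shared_parameter_gradient_stub(cur_rank, src_ranks, dst_ranks, params_grads):
    if cur_rank not in src_ranks and cur_rank not in dst_ranks:
        return params_grads  # uninvolved rank: nothing to synchronize

    pre_name = "shared_" if cur_rank in dst_ranks else ""
    # The real pass would look up `pre_name + param` gradients here and
    # all_reduce them over the pair group; omitted in this sketch.
    return params_grads

grads = [("w0", "w0@GRAD")]
assert sync_shared_parameter_gradient_stub(5, [0], [2], grads) is grads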