
Commit aa72149

[PIR-Auto-Parallel] fix comm group hang in sync shared param pass (#71524) (#71613)
1 parent 6320b01 commit aa72149

1 file changed: +12 −1 lines changed

python/paddle/distributed/passes/auto_parallel_sync_shared_params.py

Lines changed: 12 additions & 1 deletion
@@ -63,7 +63,7 @@ def _find_fist_opt_user(self, main_program):
     def _get_comm_group(self, ranks=[]):
         ranks = sorted(ranks)
         if tuple(ranks) in self.comm_group:
-            return self.comm_group[tuple(ranks)].id
+            return self.comm_group[tuple(ranks)]
         # The communication group of this `all_reduce` op satisfies len(ranks) == 2.
         # When `force_new_group=False` is set, the `send&recv` group will be returned,
         # At this point, `all_reduce` and `send&recv` share the same group, and
@@ -205,6 +205,14 @@ def sync_shared_parameters(self, main_program, startup_program):
             logger.info("No parameter need to share, skip pass.")
             return []
 
+        # Must initialize the redundant communication group for the allreduce op here.
+        # Otherwise, it will hang during gradient synchronization.
+        for idx in range(len(self.src_ranks)):
+            rank_1 = self.src_ranks[idx]
+            rank_2 = self.dst_ranks[idx]
+            new_process_group(sorted([rank_1, rank_2]))
+            self._get_comm_group([rank_1, rank_2])
+
         return new_shared_params
 
     def sync_shared_parameter_gradient(
@@ -228,6 +236,9 @@ def sync_shared_parameter_gradient(
 
         cur_rank = paddle.distributed.get_rank()
 
+        if cur_rank not in self.src_ranks and cur_rank not in self.dst_ranks:
+            return params_grads
+
         pre_name = ""
         if cur_rank in self.dst_ranks:
             pre_name = "shared_"
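
Why the eager group creation matters: creating a communication group is itself a collective step, so every rank that belongs to the group must reach the creation call. Below is a minimal, self-contained toy sketch (plain Python threads, not PaddlePaddle APIs) of one plausible way the pre-fix lazy creation can hang: a threading.Barrier stands in for collective group initialization, and two threads stand in for the two ranks of a shared-parameter pair. The scenario and all names in the sketch are illustrative assumptions, not the pass's actual control flow.

# Toy model (not PaddlePaddle code): two threads stand in for two ranks, and a
# threading.Barrier stands in for collective communication-group creation,
# which only completes once every member rank has joined it.
import threading


def create_group(barrier, timeout=1.0):
    # Pretend collective group init: blocks until all members arrive.
    try:
        barrier.wait(timeout=timeout)
        return "group ready"
    except threading.BrokenBarrierError:
        return "HANG (peer never joined group creation)"


def run(eager_init):
    barrier = threading.Barrier(2)
    results = {}

    def rank(rank_id, holds_shared_param):
        if eager_init:
            # What this commit does: every rank creates the group up front,
            # while all ranks are still executing the same pass code path.
            results[rank_id] = create_group(barrier)
            return
        # Pre-fix behaviour (illustrative): the group is created lazily during
        # gradient sync, but only ranks that hold a shared parameter get there.
        if holds_shared_param:
            results[rank_id] = create_group(barrier)
        else:
            results[rank_id] = "skipped gradient sync"

    threads = [
        threading.Thread(target=rank, args=(0, True)),
        threading.Thread(target=rank, args=(1, False)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(f"eager_init={eager_init}: {results}")


run(eager_init=False)  # rank 0 times out waiting for rank 1: the "hang"
run(eager_init=True)   # both ranks join group creation: no hang

Read against the diff above: sync_shared_parameters runs on every rank, so moving new_process_group(...) and self._get_comm_group(...) there plays the role of eager_init=True, and the cached group is reused when the gradient all_reduce later executes. The new early return in sync_shared_parameter_gradient then lets ranks outside self.src_ranks/self.dst_ranks skip gradient synchronization without stranding a partner in group setup.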
