Skip to content

Commit a04a30a

Browse files
awaelchli and kaushikb11

authored and committed
fix NCCL error with non-consecutive trainer gpus (#8165)
* device ids in barrier
* same fix for spawn
* fix non-nccl
* add changelog
* get nccl backend
* get backend

Co-authored-by: Kaushik B <[email protected]>
1 parent eb1356a commit a04a30a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,12 @@ def pre_dispatch(self):
295295
def post_dispatch(self) -> None:
296296
self.cluster_environment.teardown()
297297

298-
def barrier(self, *args, **kwargs):
299-
if torch_distrib.is_available() and torch_distrib.is_initialized():
298+
def barrier(self, *args, **kwargs) -> None:
    """Synchronize all processes in the process group.

    No-op when ``torch.distributed`` has not been initialized. On torch >= 1.8
    with the NCCL backend, the barrier is given explicit device ids so it works
    with non-consecutive trainer GPUs (see PL issue #8165); otherwise the plain
    collective barrier is used.
    """
    if not torch_distrib.is_initialized():
        return
    needs_device_ids = _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl"
    if needs_device_ids:
        # NCCL requires the barrier to run on this process's own device(s).
        torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
    else:
        torch_distrib.barrier()
301305

302306
def broadcast(self, obj: object, src: int = 0) -> object:

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,12 @@ def __recover_child_process_weights(self, best_path, last_path):
271271
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
272272
self.lightning_module.load_state_dict(ckpt)
273273

274-
def barrier(self, *args, **kwargs):
275-
if torch_distrib.is_initialized():
274+
def barrier(self, *args, **kwargs) -> None:
    """Synchronize all spawned processes in the process group.

    Returns immediately when ``torch.distributed`` is not initialized. For the
    NCCL backend on torch >= 1.8, explicit device ids are passed so the barrier
    works with non-consecutive trainer GPUs (see PL issue #8165).
    """
    if not torch_distrib.is_initialized():
        return
    if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
        # NCCL needs to know which device(s) this rank owns.
        device_ids = self.determine_ddp_device_ids()
        torch_distrib.barrier(device_ids=device_ids)
    else:
        torch_distrib.barrier()
277281

278282
def broadcast(self, obj: object, src: int = 0) -> object:

0 commit comments

Comments (0)