Skip to content

Commit ece7d38

Browse files
committed
formating
1 parent ec62397 commit ece7d38

File tree

1 file changed

+14
-14
lines changed
  • src/lightning/pytorch/strategies

1 file changed

+14
-14
lines changed

src/lightning/pytorch/strategies/ddp.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
214214
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
215215

216216
def _register_ddp_hooks(self) -> None:
217-
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
217+
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
218218
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
219219
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
220220
if self.root_device.type == "cuda":
@@ -431,25 +431,25 @@ def _setup_model(self, model: Module) -> Module:
431431
if isinstance(module, Module):
432432
ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
433433
setattr(model, name, ddp_module)
434-
435434
return model
436435

437436
@override
438437
def _register_ddp_hooks(self) -> None:
439-
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
438+
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
440439
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
441440
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
442-
if self.root_device.type == "cuda":
443-
assert isinstance(self.model, Module)
444-
445-
for name, module in self.model.named_children():
446-
assert isinstance(module, DistributedDataParallel)
447-
_register_ddp_comm_hook(
448-
model=module,
449-
ddp_comm_state=self._ddp_comm_state,
450-
ddp_comm_hook=self._ddp_comm_hook,
451-
ddp_comm_wrapper=self._ddp_comm_wrapper,
452-
)
441+
if self.root_device.type != "cuda":
442+
return
443+
assert isinstance(self.model, Module)
444+
445+
for name, module in self.model.named_children():
446+
assert isinstance(module, DistributedDataParallel)
447+
_register_ddp_comm_hook(
448+
model=module,
449+
ddp_comm_state=self._ddp_comm_state,
450+
ddp_comm_hook=self._ddp_comm_hook,
451+
ddp_comm_wrapper=self._ddp_comm_wrapper,
452+
)
453453

454454

455455
class _DDPForwardRedirection(_ForwardRedirection):

0 commit comments

Comments
 (0)