@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
214214 rank_zero_only .rank = utils_rank_zero_only .rank = self .global_rank
215215
216216 def _register_ddp_hooks (self ) -> None :
217- log .debug (f"{ self .__class__ .__name__ } : registering ddp hooks" )
217+ log .debug (f"{ self .__class__ .__name__ } : registering DDP hooks" )
218218 # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
219219 # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
220220 if self .root_device .type == "cuda" :
@@ -431,25 +431,25 @@ def _setup_model(self, model: Module) -> Module:
431431 if isinstance (module , Module ):
432432 ddp_module = DistributedDataParallel (module , device_ids = device_ids , ** self ._ddp_kwargs )
433433 setattr (model , name , ddp_module )
434-
435434 return model
436435
437436 @override
438437 def _register_ddp_hooks (self ) -> None :
439- log .debug (f"{ self .__class__ .__name__ } : registering ddp hooks" )
438+ log .debug (f"{ self .__class__ .__name__ } : registering DDP hooks" )
440439 # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
441440 # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
442- if self .root_device .type == "cuda" :
443- assert isinstance (self .model , Module )
444-
445- for name , module in self .model .named_children ():
446- assert isinstance (module , DistributedDataParallel )
447- _register_ddp_comm_hook (
448- model = module ,
449- ddp_comm_state = self ._ddp_comm_state ,
450- ddp_comm_hook = self ._ddp_comm_hook ,
451- ddp_comm_wrapper = self ._ddp_comm_wrapper ,
452- )
441+ if self .root_device .type != "cuda" :
442+ return
443+ assert isinstance (self .model , Module )
444+
445+ for name , module in self .model .named_children ():
446+ assert isinstance (module , DistributedDataParallel )
447+ _register_ddp_comm_hook (
448+ model = module ,
449+ ddp_comm_state = self ._ddp_comm_state ,
450+ ddp_comm_hook = self ._ddp_comm_hook ,
451+ ddp_comm_wrapper = self ._ddp_comm_wrapper ,
452+ )
453453
454454
455455class _DDPForwardRedirection (_ForwardRedirection ):
0 commit comments