Commit 26c80c5

awaelchli authored and lexierule committed

resurface lost ddp info message (#8111)

1 parent f56df26 commit 26c80c5
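
Why the message was lost: as the removed hunks below show, the old banner was guarded by self.is_global_zero and not torch.distributed.is_initialized(), but both setup_distributed() and new_process() reach that guard only after init_ddp_connection() has already initialized the process group, so the condition could never be true. A minimal sketch of the dead guard (single-process "gloo" group and hypothetical rendezvous values, not part of the diff):

    import os
    import torch.distributed as dist

    # Stand-in for what init_ddp_connection() has already done by the
    # time the old guard ran.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # The removed guard, evaluated after initialization:
    print(not dist.is_initialized())  # False -> the banner block never ran

The fix moves the banner into init_ddp_connection() itself, right after init_process_group(), and routes it through rank_zero_info so only rank 0 prints it.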

File tree: 2 files changed, +18 -16 lines

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 9 additions & 8 deletions
@@ -35,7 +35,7 @@
     _TORCH_GREATER_EQUAL_1_8,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed

@@ -197,13 +197,6 @@ def setup_distributed(self):
         # where to store ip_table
         self.init_ddp_connection()

-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
         # set the ranks and devices
         self.dist.rank = self.global_rank
         self.dist.device = self.root_device
@@ -271,6 +264,14 @@ def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

+        # on rank=0 let everyone know training is starting
+        rank_zero_info(
+            f"{'-' * 100}\n"
+            f"distributed_backend={self.torch_distributed_backend}\n"
+            f"All DDP processes registered. Starting ddp with {self.world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
+
     def pre_dispatch(self):
         # move the model to the correct device
         self.model_to_device()
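
Both files replace the four hand-guarded log.info calls with a single rank_zero_info call, so the banner is emitted exactly once, by global rank 0. Roughly, rank_zero_info in pytorch_lightning.utilities.distributed is log.info wrapped to no-op on every process except rank 0; the sketch below is an approximation of that helper for illustration, not the library's exact code:

    import logging
    import os
    from functools import wraps

    log = logging.getLogger(__name__)

    def rank_zero_only(fn):
        # Run the wrapped function only on the process whose rank is 0.
        @wraps(fn)
        def wrapped_fn(*args, **kwargs):
            if rank_zero_only.rank == 0:
                return fn(*args, **kwargs)
        return wrapped_fn

    # The rank is read from the environment set by the DDP launcher.
    rank_zero_only.rank = int(os.environ.get("LOCAL_RANK", 0))

    @rank_zero_only
    def rank_zero_info(*args, **kwargs):
        log.info(*args, **kwargs)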

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 9 additions & 8 deletions
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.seed import reset_seed

 if _TORCH_GREATER_EQUAL_1_8:
@@ -148,13 +148,6 @@ def new_process(self, process_idx, trainer, mp_queue):
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
         # set the ranks and devices
         self.dist.rank = self.global_rank
         self.dist.device = self.root_device
@@ -230,6 +223,14 @@ def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None:
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

+        # on rank=0 let everyone know training is starting
+        rank_zero_info(
+            f"{'-' * 100}\n"
+            f"distributed_backend={self.torch_distributed_backend}\n"
+            f"All DDP processes registered. Starting ddp with {self.world_size} processes\n"
+            f"{'-' * 100}\n"
+        )
+
     def determine_ddp_device_ids(self):
         if self.root_device.type == "cpu":
             return None
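
The spawn plugin mirrors the change exactly. For reference, a standalone reproduction of the restored banner; the backend name "nccl" and world_size=2 are assumed example values, while the f-string is the one added in both files above:

    import logging

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    log = logging.getLogger("pytorch_lightning")

    # Assumed example values; in the plugins these come from the plugin itself.
    torch_distributed_backend, world_size = "nccl", 2

    log.info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All DDP processes registered. Starting ddp with {world_size} processes\n"
        f"{'-' * 100}\n"
    )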
