 from typing import Any, Optional, Union
 
 import torch
-import torch.distributed as dist
 from typing_extensions import override
 
 import lightning.pytorch as pl
@@ -258,13 +257,13 @@ def _broadcast_sigterm_tensor(self):
                 [1 if getattr(self.trainer, "received_sigterm", False) else 0],
                 device=self.trainer.strategy.root_device,
             )
-            dist.broadcast(sigterm_tensor, src=0)
+            torch.distributed.broadcast(sigterm_tensor, src=0)
         except Exception:
             sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
 
         if sigterm_tensor.item() == 1:
             with contextlib.suppress(Exception):
-                dist.barrier()  # prevent deadlocks by syncing all ranks before exit
+                torch.distributed.barrier()  # prevent deadlocks by syncing all ranks before exit
             raise SIGTERMException()
 
     def advance(self, data_fetcher: _DataFetcher) -> None:
@@ -292,7 +291,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
 
         # =====================================================================
 
-        if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
+        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.trainer.world_size > 1:
            self._broadcast_sigterm_tensor()
 
         # =====================================================================
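For context, the pattern these hunks touch is: rank 0 records whether a SIGTERM was received, broadcasts that flag to every rank, and all ranks synchronize on a barrier before raising, so no rank is left hanging inside a pending collective. Below is a minimal, self-contained sketch of that pattern using only torch.distributed; the standalone broadcast_sigterm_flag helper and its signature are illustrative assumptions, not part of this change.

# Illustrative sketch only -- helper name and signature are assumptions,
# not part of the patch above.
import contextlib

import torch
import torch.distributed


def broadcast_sigterm_flag(received_sigterm: bool, device: torch.device) -> bool:
    """Return True on every rank if rank 0 observed a SIGTERM."""
    flag = torch.tensor([1 if received_sigterm else 0], device=device)
    try:
        # Rank 0's value overwrites the flag on all other ranks.
        torch.distributed.broadcast(flag, src=0)
    except Exception:
        # If the collective fails, fall back to "no signal received".
        flag = torch.tensor([0], device=device)

    if flag.item() == 1:
        with contextlib.suppress(Exception):
            # Sync all ranks before exiting so none is stuck waiting on a collective.
            torch.distributed.barrier()
        return True
    return False

As in the guard added to advance(), this would only be called when torch.distributed.is_available(), torch.distributed.is_initialized(), and the world size is greater than 1.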