|
16 | 16 | import base64 |
17 | 17 | import glob |
18 | 18 | import json |
19 | | -import logging |
20 | 19 | import os |
21 | 20 | import subprocess |
22 | 21 | import tempfile |
|
30 | 29 | from invoke.context import Context |
31 | 30 |
|
32 | 31 | from nemo_run.config import get_nemorun_home |
| 32 | +from nemo_run.core.console import CONSOLE |
| 33 | +from nemo_run.core.constants import RUNDIR_NAME |
33 | 34 | from nemo_run.core.execution.base import Executor, ExecutorMacros |
34 | | -from nemo_run.core.execution.launcher import FaultTolerance, Launcher |
| 35 | +from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun |
35 | 36 | from nemo_run.core.execution.utils import fill_template |
36 | 37 | from nemo_run.core.packaging.base import Packager |
37 | 38 | from nemo_run.core.packaging.git import GitArchivePackager |
38 | 39 |
|
39 | | -logger = logging.getLogger(__name__) |
40 | | - |
41 | 40 |
|
42 | 41 | class DGXCloudState(Enum): |
43 | 42 | CREATING = "Creating" |
@@ -463,6 +462,24 @@ def cancel(self, job_id: str): |
463 | 462 | response.text, |
464 | 463 | ) |
465 | 464 |
|
| 465 | + def _setup_launcher(self): |
| 466 | + super()._setup_launcher() |
| 467 | + launcher = self.launcher |
| 468 | + if launcher and isinstance(launcher, (FaultTolerance, Torchrun)): |
| 469 | + self.torchrun_nproc_per_node = self.nprocs_per_node |
| 470 | + self.ntasks_per_node = 1 |
| 471 | + CONSOLE.log( |
| 472 | + f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}" |
| 473 | + ) |
| 474 | + |
| 475 | + if launcher and isinstance(launcher, FaultTolerance): |
| 476 | + base_dir = os.path.join(self.job_dir, Path(self.job_dir).name) |
| 477 | + launcher.cfg_path = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml") |
| 478 | + launcher.finished_flag_file = os.path.join( |
| 479 | + "/", RUNDIR_NAME, f"{self.job_name}_finished_flag" |
| 480 | + ) |
| 481 | + launcher.job_results_file = os.path.join(base_dir, f"{self.job_name}_job_results") |
| 482 | + |
466 | 483 | def cleanup(self, handle: str): ... |
467 | 484 |
|
468 | 485 | def assign( |
|
0 commit comments