Skip to content

Commit 4e15269

Browse files
committed
fix
Signed-off-by: oliver könig <[email protected]>
1 parent f259248 commit 4e15269

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

nemo_run/core/execution/dgxcloud.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
import base64
1717
import glob
1818
import json
19-
import logging
2019
import os
2120
import subprocess
2221
import tempfile
@@ -30,14 +29,14 @@
3029
from invoke.context import Context
3130

3231
from nemo_run.config import get_nemorun_home
32+
from nemo_run.core.console import CONSOLE
33+
from nemo_run.core.constants import RUNDIR_NAME
3334
from nemo_run.core.execution.base import Executor, ExecutorMacros
34-
from nemo_run.core.execution.launcher import FaultTolerance, Launcher
35+
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
3536
from nemo_run.core.execution.utils import fill_template
3637
from nemo_run.core.packaging.base import Packager
3738
from nemo_run.core.packaging.git import GitArchivePackager
3839

39-
logger = logging.getLogger(__name__)
40-
4140

4241
class DGXCloudState(Enum):
4342
CREATING = "Creating"
@@ -463,6 +462,24 @@ def cancel(self, job_id: str):
463462
response.text,
464463
)
465464

465+
def _setup_launcher(self):
    """Adjust executor and launcher settings after the base-class launcher setup.

    Runs the parent ``_setup_launcher`` first, then inspects ``self.launcher``:

    * For a ``FaultTolerance`` or ``Torchrun`` launcher, ``nprocs_per_node`` is
      copied into ``torchrun_nproc_per_node``, ``ntasks_per_node`` is set to 1,
      and the change is logged through ``CONSOLE``.
    * For a ``FaultTolerance`` launcher, ``cfg_path``, ``finished_flag_file``
      and ``job_results_file`` are additionally set on the launcher, each named
      after ``self.job_name``.

    A missing (falsy) launcher leaves everything untouched.
    """
    super()._setup_launcher()

    launcher = self.launcher
    if not launcher:
        return

    if isinstance(launcher, (FaultTolerance, Torchrun)):
        # Swap the two assignments' visual order only in the log text, not in
        # execution: the process count is captured before ntasks is reset.
        self.torchrun_nproc_per_node = self.nprocs_per_node
        self.ntasks_per_node = 1
        CONSOLE.log(
            f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
        )

    if not isinstance(launcher, FaultTolerance):
        return

    # Config and results files live in a sub-directory of job_dir that repeats
    # job_dir's own final path component.
    base_dir = os.path.join(self.job_dir, Path(self.job_dir).name)
    cfg_file = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml")
    # The finished flag is placed under the absolute /<RUNDIR_NAME> path rather
    # than under base_dir.
    flag_file = os.path.join("/", RUNDIR_NAME, f"{self.job_name}_finished_flag")
    results_file = os.path.join(base_dir, f"{self.job_name}_job_results")

    launcher.cfg_path = cfg_file
    launcher.finished_flag_file = flag_file
    launcher.job_results_file = results_file
466483
def cleanup(self, handle: str):
    """Do nothing: ``handle`` is accepted but unused, and ``None`` is returned."""
    return None
467484

468485
def assign(

0 commit comments

Comments
 (0)