Skip to content

Commit dd013a7

Browse files
authored
Fix dependencies in Slurm Executor and local inline scripts (#93)
* Fix dependencies in Slurm Executor

  Signed-off-by: Hemil Desai <[email protected]>

* Optimize

  Signed-off-by: Hemil Desai <[email protected]>

* Fix local inline scripts

  Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
1 parent 9fc7481 commit dd013a7

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

src/nemo_run/config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,15 +425,19 @@ def get_name(self):
425425
return os.path.basename(self.path)
426426

427427
def to_command(
428-
self, with_entrypoint: bool = False, filename: Optional[str] = None
428+
self, with_entrypoint: bool = False, filename: Optional[str] = None, is_local: bool = False
429429
) -> list[str]:
430430
if self.inline:
431431
if filename:
432432
os.makedirs(os.path.dirname(filename), exist_ok=True)
433433
with open(filename, "w") as f:
434434
f.write("#!/usr/bin/bash\n" + self.inline)
435435

436-
cmd = [os.path.join(f"/{RUNDIR_NAME}", SCRIPTS_DIR, Path(filename).name)]
436+
if is_local:
437+
cmd = [filename]
438+
else:
439+
cmd = [os.path.join(f"/{RUNDIR_NAME}", SCRIPTS_DIR, Path(filename).name)]
440+
437441
if with_entrypoint:
438442
cmd = [self.entrypoint] + cmd
439443

src/nemo_run/run/torchx_backend/packaging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from nemo_run.config import SCRIPTS_DIR, Partial, Script
2525
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
26+
from nemo_run.core.execution.local import LocalExecutor
2627
from nemo_run.core.serialization.yaml import YamlSerializer
2728
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
2829
from nemo_run.run.torchx_backend.components import ft_launcher, torchrun
@@ -119,7 +120,8 @@ def package(
119120

120121
args = fn_or_script.args
121122
role_args = fn_or_script.to_command(
122-
filename=os.path.join(executor.job_dir, SCRIPTS_DIR, f"{name}.sh")
123+
filename=os.path.join(executor.job_dir, SCRIPTS_DIR, f"{name}.sh"),
124+
is_local=True if isinstance(executor, LocalExecutor) else False,
123125
)
124126
m = fn_or_script.path if fn_or_script.m else None
125127
no_python = fn_or_script.entrypoint != "python"

src/nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,24 +144,31 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
144144
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore
145145
# Setup
146146
req = dryrun_info.request
147-
slurm_cfg = dryrun_info.request.slurm_config
148-
assert slurm_cfg.experiment_id, "Executor not assigned to experiment."
147+
slurm_executor = dryrun_info.request.slurm_config
148+
assert slurm_executor.experiment_id, "Executor not assigned to experiment."
149149

150-
job_dir = slurm_cfg.job_dir
151-
tunnel = slurm_cfg.tunnel
150+
job_dir = slurm_executor.job_dir
151+
tunnel = slurm_executor.tunnel
152152
assert tunnel, f"Tunnel required for {self.__class__}"
153153
assert job_dir, f"Need to provide job_dir for {self.__class__}"
154154

155155
self._initialize_tunnel(tunnel)
156156
assert self.tunnel, f"Cannot initialize tunnel {tunnel}"
157157

158-
dst_path = os.path.join(self.tunnel.job_dir, f"{slurm_cfg.job_name}_sbatch.sh")
158+
dst_path = os.path.join(self.tunnel.job_dir, f"{slurm_executor.job_name}_sbatch.sh")
159+
160+
if slurm_executor.dependencies:
161+
cmd = ["sbatch", "--requeue", "--parsable"]
162+
slurm_deps = slurm_executor.parse_deps()
163+
cmd.append(f"--dependency={slurm_executor.dependency_type}:{':'.join(slurm_deps)}")
164+
req.cmd = cmd
165+
159166
# Run sbatch script
160167
req.cmd += [dst_path]
161168
job_id = self.tunnel.run(" ".join(req.cmd)).stdout.strip()
162169

163170
# Save metadata
164-
_save_job_dir(job_id, job_dir, tunnel, slurm_cfg.job_details.ls_term)
171+
_save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)
165172
return job_id
166173

167174
def _cancel_existing(self, app_id: str) -> None:

0 commit comments

Comments (0)