Skip to content

Commit caf35f5

Browse files
committed
Add option to symlink from remote dir in packager
Signed-off-by: Hemil Desai <[email protected]>
1 parent 5b40e95 commit caf35f5

File tree

8 files changed

+58 additions, -19 deletions

src/nemo_run/core/execution/slurm.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
from nemo_run.core.packaging.git import GitArchivePackager
4444
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
4545
from nemo_run.core.tunnel.callback import Callback
46-
from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel
46+
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHConfigFile, SSHTunnel, Tunnel
4747
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
4848
from nemo_run.devspace.base import DevSpace
4949

@@ -388,7 +388,7 @@ def __post_init__(self):
388388
self.wait_time_for_group_job = 0
389389

390390
def info(self) -> str:
391-
return f"{self.__class__.__qualname__} on {self.tunnel._key}"
391+
return f"{self.__class__.__qualname__} on {self.tunnel.key}"
392392

393393
def alloc(self, job_name="interactive"):
394394
self.job_name = f"{self.job_name_prefix}{job_name}"
@@ -544,6 +544,21 @@ def package(self, packager: Packager, job_name: str):
544544
)
545545
return
546546

547+
if packager.symlink_from_remote_dir:
548+
logger.info(
549+
f"Packager {packager} is configured to symlink from remote dir. Skipping packaging."
550+
)
551+
if type(packager) is Packager:
552+
self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False)
553+
return
554+
555+
self.tunnel.packaging_jobs[job_name] = PackagingJob(
556+
symlink=True,
557+
src_path=packager.symlink_from_remote_dir,
558+
dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
559+
)
560+
return
561+
547562
assert self.experiment_id, "Executor not assigned to an experiment."
548563
if isinstance(packager, GitArchivePackager):
549564
output = subprocess.run(
@@ -573,7 +588,12 @@ def package(self, packager: Packager, job_name: str):
573588
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
574589
)
575590

576-
self.tunnel.packaging_jobs.add(job_name)
591+
self.tunnel.packaging_jobs[job_name] = PackagingJob(
592+
symlink=False,
593+
dst_path=None
594+
if type(packager) is Packager
595+
else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
596+
)
577597

578598
def parse_deps(self) -> list[str]:
579599
"""

src/nemo_run/core/packaging/base.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
from dataclasses import dataclass
1818
from pathlib import Path
19-
19+
from typing import Optional
2020

2121
from nemo_run.config import ConfigurableMixin
2222

@@ -45,6 +45,10 @@ class Packager(ConfigurableMixin):
4545
#: Uses component or executor specific debug flags if set to True.
4646
debug: bool = False
4747

48+
#: Symlinks the package from the provided remote dir.
49+
#: Only applicable when using SlurmExecutor at the moment.
50+
symlink_from_remote_dir: Optional[str] = None
51+
4852
def package(self, path: Path, job_dir: str, name: str) -> str: ...
4953

5054
def setup(self):

src/nemo_run/core/tunnel/client.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,15 +57,25 @@ def authentication_handler(title, instructions, prompt_list):
5757
return [getpass.getpass(str(pr[0])) for pr in prompt_list]
5858

5959

60+
@dataclass(kw_only=True)
61+
class PackagingJob:
62+
symlink: bool = False
63+
src_path: Optional[str] = None
64+
dst_path: Optional[str] = None
65+
66+
def symlink_cmd(self):
67+
return f"ln -s {self.src_path} {self.dst_path}"
68+
69+
6070
@dataclass(kw_only=True)
6171
class Tunnel(ABC):
6272
job_dir: str
6373
host: str
6474
user: str
6575

6676
def __post_init__(self):
67-
self._key = f"{self.user}@{self.host}"
68-
self._packaging_jobs = set()
77+
self.key = f"{self.user}@{self.host}"
78+
self._packaging_jobs: dict[str, PackagingJob] = {}
6979

7080
@property
7181
def packaging_jobs(self):

src/nemo_run/run/experiment.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -645,9 +645,17 @@ def run(
645645
for tunnel in self.tunnels.values():
646646
if isinstance(tunnel, SSHTunnel):
647647
tunnel.connect()
648-
assert tunnel.session, f"SSH tunnel {tunnel._key} failed to connect."
648+
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
649649
rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir))
650650

651+
symlink_cmds = []
652+
for packaging_job in tunnel.packaging_jobs.values():
653+
if packaging_job.symlink:
654+
symlink_cmds.append(packaging_job.symlink_cmd())
655+
656+
if symlink_cmds:
657+
tunnel.run(" && ".join(symlink_cmds))
658+
651659
return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors)
652660

653661
def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):

src/nemo_run/run/torchx_backend/packaging.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -82,11 +82,6 @@ def package(
8282

8383
args.append(fn_or_script_filename)
8484
else:
85-
args += [
86-
"-p",
87-
_serialize(executor.packager.to_config()),
88-
]
89-
9085
args.append(_serialize(fn_or_script))
9186

9287
role_args = default_cmd + args

src/nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,14 +72,14 @@ def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
7272
return
7373

7474
experiment = run_experiment._current_experiment.get(None)
75-
if experiment and tunnel._key in experiment.tunnels:
76-
self.tunnel = experiment.tunnels[tunnel._key]
75+
if experiment and tunnel.key in experiment.tunnels:
76+
self.tunnel = experiment.tunnels[tunnel.key]
7777
return
7878

7979
self.tunnel = tunnel
8080

8181
if experiment:
82-
experiment.tunnels[tunnel._key] = self.tunnel
82+
experiment.tunnels[tunnel.key] = self.tunnel
8383

8484
def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # type: ignore
8585
assert isinstance(cfg, SlurmExecutor), f"{cfg.__class__} not supported for slurm scheduler."

test/run/torchx_backend/test_packaging.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,6 @@ def test_package_partial(mock_executor):
7171
"nemo_run.core.runners.fdl_runner",
7272
"-n",
7373
"test",
74-
"-p",
75-
"eJzdVE1Lw0AQ_StlL60goelRrKCePAgevJUSNtlJunazG_ajGEr_uzvbRNPYtBU9eQnZYea9N59bopWy5Ga0JbauwP8QDTm5HpE11PiqaLamBegkJjtvVekbZNbsA1wlwNs7wV8oVd3glIo5EUyp48JyadAqaRlsASMgcwsl4i4W5EkyeJ9w_M6nV-jOIHUFWS4RDyxl1FLvKp0QGMp4Zn-pAyEOZfS4_jTFY3l8JjL7B4lgOKOpHw-r6Qa08QPU-v2gUzlnTECUGJ1FmZI5L7p6HlqS15bjuZXSG6h7a_UEw-bjXCZKJ5kwYxysk-wSSpVoJz21hmi_CFwWUUoNdHW8NCtCdr4cB2RUF64EaRN89hkP96xdpmEMS4vTEM0aDCOsuLFK1-dBZh5kaGz-tk_YqM6JuXQwOq3psz3uLcMTEG5JrwYCaGDYUOHQkFNhIEizq4BAbvFO3kXNITpRnsNynlmE3REOj45mj5MpJ7_d2rh78OLu0YgvWby4X_AYydCTK4mKp9E08sI-AMy28Wg=",
7674
"eJzdlEtPwzAMgP_KlMuGhKp1RwRIcOOAxIHbNEVuk21haVKlzrRq2n8nDuto9wSxE5eqduLPdvxYM2ctsrvemmFdyvDDnJyy2x5byJok4Yui5iAET9kmqG32IXOsvix8qWXQt6y_MWW9BRVWeB1VmVcalalIa6CIusiIZIWyIO54zF6MkKuBou_D8IauA5vc9roHaTzI2GRCTiSCAIRgb7zWxBMqxz8GR4hubHu-rpr3sTx2iYz-QSJkLiALPYMOltJV0vHm3i8qNVVCaJnwyuVJbs1UzdrxPDdO3hsfr00oe132hOgGZPbQnxpuHc911aemOusdrcvnK55BvpBGJCgr5GUQYKZMJ5Dd5LBN7N2WO3AzX0iDnMR9n935a2bsNANhdh6xHYTThLmqQlb1ZcgoQE41znUrFfu-tXp-2htGFpY7b464ewOHCvSZLoC9F9ASIn4J2pMiDf8l4DxasnvanI9J2EwHL5tdAI2OgTICnXzdbjUuTNLmCD_QSR04ufXmYIOn7Y25E0Zb4eLkpgf1SskbXVXWUMjDZJiEyD4BcFMKTQ==",
7775
]
7876

test/test_api.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,9 @@
1616
from dataclasses import dataclass
1717
from unittest.mock import Mock
1818

19-
import nemo_run as run
2019
import pytest
20+
21+
import nemo_run as run
2122
from nemo_run.api import dryrun_fn
2223

2324

@@ -117,7 +118,10 @@ def test_dryrun_fn_with_executor(self, capsys, configured_fn):
117118

118119
captured = capsys.readouterr()
119120
assert "Dry run for task test.test_api:some_fn" in captured.out
120-
assert "LocalExecutor(packager=Packager(debug=False)" in captured.out
121+
assert (
122+
"LocalExecutor(packager=Packager(debug=False, symlink_from_remote_dir=None)"
123+
in captured.out
124+
)
121125

122126
def test_dryrun_fn_with_build(self, mocker, configured_fn):
123127
build_mock = Mock()

0 commit comments

Comments
 (0)