Skip to content

Commit b4e2258

Browse files
authored
Add option to symlink from remote dir in packager (#122)
* Add option to symlink from remote dir in packager
  Signed-off-by: Hemil Desai <[email protected]>
* Save tunnels for experiment
  Signed-off-by: Hemil Desai <[email protected]>
* Fix
  Signed-off-by: Hemil Desai <[email protected]>
* Mount base remote dir for symlinks
  Signed-off-by: Hemil Desai <[email protected]>
* fix
  Signed-off-by: Hemil Desai <[email protected]>
---------
Signed-off-by: Hemil Desai <[email protected]>
1 parent 283c0a1 commit b4e2258

File tree

10 files changed

+132
-79
lines changed

10 files changed

+132
-79
lines changed

src/nemo_run/core/execution/slurm.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,14 @@
4242
from nemo_run.core.packaging.base import Packager
4343
from nemo_run.core.packaging.git import GitArchivePackager
4444
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
45-
from nemo_run.core.tunnel.callback import Callback
46-
from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel
45+
from nemo_run.core.tunnel.client import (
46+
Callback,
47+
LocalTunnel,
48+
PackagingJob,
49+
SSHConfigFile,
50+
SSHTunnel,
51+
Tunnel,
52+
)
4753
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
4854
from nemo_run.devspace.base import DevSpace
4955

@@ -388,7 +394,7 @@ def __post_init__(self):
388394
self.wait_time_for_group_job = 0
389395

390396
def info(self) -> str:
391-
return f"{self.__class__.__qualname__} on {self.tunnel._key}"
397+
return f"{self.__class__.__qualname__} on {self.tunnel.key}"
392398

393399
def alloc(self, job_name="interactive"):
394400
self.job_name = f"{self.job_name_prefix}{job_name}"
@@ -537,13 +543,39 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
537543
return filenames
538544

539545
def package(self, packager: Packager, job_name: str):
540-
if job_name in self.tunnel.packaging_jobs:
546+
if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir:
541547
logger.info(
542548
f"Packaging for job {job_name} in tunnel {self.tunnel} already done. Skipping subsequent packagings.\n"
543549
"This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used."
544550
)
545551
return
546552

553+
if packager.symlink_from_remote_dir:
554+
logger.info(
555+
f"Packager {packager} is configured to symlink from remote dir. Skipping packaging."
556+
)
557+
if type(packager) is Packager:
558+
self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False)
559+
return
560+
561+
self.tunnel.packaging_jobs[job_name] = PackagingJob(
562+
symlink=True,
563+
src_path=packager.symlink_from_remote_dir,
564+
dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
565+
)
566+
567+
# Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up
568+
base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent)
569+
base_remote_mount = f"{base_remote_dir}:{base_remote_dir}"
570+
if base_remote_mount not in self.container_mounts:
571+
self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}")
572+
573+
for req in self.resource_group:
574+
if base_remote_mount not in req.container_mounts:
575+
req.container_mounts.append(base_remote_mount)
576+
577+
return
578+
547579
assert self.experiment_id, "Executor not assigned to an experiment."
548580
if isinstance(packager, GitArchivePackager):
549581
output = subprocess.run(
@@ -573,7 +605,12 @@ def package(self, packager: Packager, job_name: str):
573605
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
574606
)
575607

576-
self.tunnel.packaging_jobs.add(job_name)
608+
self.tunnel.packaging_jobs[job_name] = PackagingJob(
609+
symlink=False,
610+
dst_path=None
611+
if type(packager) is Packager
612+
else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
613+
)
577614

578615
def parse_deps(self) -> list[str]:
579616
"""

src/nemo_run/core/packaging/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
from dataclasses import dataclass
1818
from pathlib import Path
19-
19+
from typing import Optional
2020

2121
from nemo_run.config import ConfigurableMixin
2222

@@ -45,6 +45,10 @@ class Packager(ConfigurableMixin):
4545
#: Uses component or executor specific debug flags if set to True.
4646
debug: bool = False
4747

48+
#: Symlinks the package from the provided remote dir.
49+
#: Only applicable when using SlurmExecutor at the moment.
50+
symlink_from_remote_dir: Optional[str] = None
51+
4852
def package(self, path: Path, job_dir: str, name: str) -> str: ...
4953

5054
def setup(self):

src/nemo_run/core/tunnel/callback.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/nemo_run/core/tunnel/client.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,17 @@
2424
from abc import ABC, abstractmethod
2525
from dataclasses import dataclass, field
2626
from pathlib import Path
27-
from typing import TYPE_CHECKING, Callable, Optional
27+
from typing import Callable, Optional
2828

2929
import paramiko
3030
import paramiko.ssh_exception
3131
from fabric import Config, Connection
3232
from invoke.context import Context
3333
from invoke.runners import Result as RunResult
3434

35-
from nemo_run.config import NEMORUN_HOME
35+
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
3636
from nemo_run.core.frontend.console.api import CONSOLE
3737

38-
if TYPE_CHECKING:
39-
from nemo_run.core.tunnel.callback import Callback
40-
4138
logger: logging.Logger = logging.getLogger(__name__)
4239
TUNNEL_DIR = ".tunnels"
4340
TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR)
@@ -58,18 +55,24 @@ def authentication_handler(title, instructions, prompt_list):
5855

5956

6057
@dataclass(kw_only=True)
61-
class Tunnel(ABC):
58+
class PackagingJob(ConfigurableMixin):
59+
symlink: bool = False
60+
src_path: Optional[str] = None
61+
dst_path: Optional[str] = None
62+
63+
def symlink_cmd(self):
64+
return f"ln -s {self.src_path} {self.dst_path}"
65+
66+
67+
@dataclass(kw_only=True)
68+
class Tunnel(ABC, ConfigurableMixin):
6269
job_dir: str
6370
host: str
6471
user: str
72+
packaging_jobs: dict[str, PackagingJob] = field(default_factory=dict)
6573

6674
def __post_init__(self):
67-
self._key = f"{self.user}@{self.host}"
68-
self._packaging_jobs = set()
69-
70-
@property
71-
def packaging_jobs(self):
72-
return self._packaging_jobs
75+
self.key = f"{self.user}@{self.host}"
7376

7477
def _set_job_dir(self, experiment_id: str): ...
7578

@@ -377,3 +380,29 @@ def remove_entry(self, name: str):
377380
file.writelines(lines)
378381

379382
print(f"Removed SSH config entry for {host}.")
383+
384+
385+
class Callback:
386+
def setup(self, tunnel: "Tunnel"):
387+
"""Called when the tunnel is setup."""
388+
self.tunnel = tunnel
389+
390+
def on_start(self):
391+
"""Called when the keep_alive loop starts."""
392+
pass
393+
394+
def on_interval(self):
395+
"""Called at each interval during the keep_alive loop."""
396+
pass
397+
398+
def on_stop(self):
399+
"""Called when the keep_alive loop stops."""
400+
pass
401+
402+
def on_error(self, error: Exception):
403+
"""Called when an error occurs during the keep_alive loop.
404+
405+
Args:
406+
error (Exception): The exception that was raised.
407+
"""
408+
pass

src/nemo_run/devspace/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import fiddle as fdl
2020

2121
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
22-
from nemo_run.core.tunnel.callback import Callback
22+
from nemo_run.core.tunnel.client import Callback
2323

2424
if TYPE_CHECKING:
2525
from nemo_run.core.execution.base import Executor

src/nemo_run/run/experiment.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ class Experiment(ConfigurableMixin):
190190
_VERSION_FILE = "_VERSION"
191191
_TASK_FILE = "_TASKS"
192192
_DONE_FILE = "_DONE"
193+
_TUNNELS_FILE = "_TUNNELS"
193194
_current_experiment_token: Optional[contextvars.Token]
194195

195196
@classmethod
@@ -221,6 +222,12 @@ def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment":
221222

222223
exp: "Experiment" = fdl.build(cfg)
223224
exp._jobs = exp._load_jobs()
225+
try:
226+
exp.tunnels = exp._load_tunnels()
227+
except Exception as e:
228+
exp.console.log(
229+
f"Exception {e} loading tunnels for experiment {id}, will continue without loading tunnels."
230+
)
224231

225232
return exp
226233

@@ -327,6 +334,20 @@ def _save_config(self):
327334
with open(os.path.join(self._exp_dir, self.__class__._VERSION_FILE), "w+") as f:
328335
f.write(f"{run.__version__}\n")
329336

337+
def _save_tunnels(self):
338+
serializer = ZlibJSONSerializer()
339+
serialized_tunnels = {
340+
k: serializer.serialize(v.to_config()) for k, v in self.tunnels.items()
341+
}
342+
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE), "w+") as f:
343+
json.dump(serialized_tunnels, f)
344+
345+
def _load_tunnels(self) -> dict[str, Tunnel]:
346+
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)) as f:
347+
serialized_tunnels = json.load(f)
348+
serializer = ZlibJSONSerializer()
349+
return {k: fdl.build(serializer.deserialize(v)) for k, v in serialized_tunnels.items()}
350+
330351
def _save_jobs(self):
331352
serialized_jobs = list(map(lambda job: job.serialize(), self.jobs))
332353
with open(os.path.join(self._exp_dir, self.__class__._TASK_FILE), "w+") as f:
@@ -645,9 +666,19 @@ def run(
645666
for tunnel in self.tunnels.values():
646667
if isinstance(tunnel, SSHTunnel):
647668
tunnel.connect()
648-
assert tunnel.session, f"SSH tunnel {tunnel._key} failed to connect."
669+
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
649670
rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir))
650671

672+
symlink_cmds = []
673+
for packaging_job in tunnel.packaging_jobs.values():
674+
if packaging_job.symlink:
675+
symlink_cmds.append(packaging_job.symlink_cmd())
676+
677+
if symlink_cmds:
678+
tunnel.run(" && ".join(symlink_cmds))
679+
680+
self._save_tunnels()
681+
651682
return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors)
652683

653684
def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):

src/nemo_run/run/torchx_backend/packaging.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,6 @@ def package(
8282

8383
args.append(fn_or_script_filename)
8484
else:
85-
args += [
86-
"-p",
87-
_serialize(executor.packager.to_config()),
88-
]
89-
9085
args.append(_serialize(fn_or_script))
9186

9287
role_args = default_cmd + args

src/nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,14 @@ def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
7272
return
7373

7474
experiment = run_experiment._current_experiment.get(None)
75-
if experiment and tunnel._key in experiment.tunnels:
76-
self.tunnel = experiment.tunnels[tunnel._key]
75+
if experiment and tunnel.key in experiment.tunnels:
76+
self.tunnel = experiment.tunnels[tunnel.key]
7777
return
7878

7979
self.tunnel = tunnel
8080

8181
if experiment:
82-
experiment.tunnels[tunnel._key] = self.tunnel
82+
experiment.tunnels[tunnel.key] = self.tunnel
8383

8484
def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # type: ignore
8585
assert isinstance(cfg, SlurmExecutor), f"{cfg.__class__} not supported for slurm scheduler."
@@ -96,6 +96,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
9696
partition = executor.partition
9797
assert partition is None or isinstance(partition, str), "partition must be str"
9898

99+
executor.package(packager=executor.packager, job_name=Path(job_dir).name)
100+
99101
srun_cmds: list[list[str]] = []
100102
jobs = []
101103
envs = {}
@@ -137,8 +139,6 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
137139
with open(path, "w") as f:
138140
f.write(script)
139141

140-
executor.package(packager=executor.packager, job_name=Path(job_dir).name)
141-
142142
return AppDryRunInfo(req, repr)
143143

144144
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore

test/run/torchx_backend/test_packaging.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ def test_package_partial(mock_executor):
7171
"nemo_run.core.runners.fdl_runner",
7272
"-n",
7373
"test",
74-
"-p",
75-
"eJzdVE1Lw0AQ_StlL60goelRrKCePAgevJUSNtlJunazG_ajGEr_uzvbRNPYtBU9eQnZYea9N59bopWy5Ga0JbauwP8QDTm5HpE11PiqaLamBegkJjtvVekbZNbsA1wlwNs7wV8oVd3glIo5EUyp48JyadAqaRlsASMgcwsl4i4W5EkyeJ9w_M6nV-jOIHUFWS4RDyxl1FLvKp0QGMp4Zn-pAyEOZfS4_jTFY3l8JjL7B4lgOKOpHw-r6Qa08QPU-v2gUzlnTECUGJ1FmZI5L7p6HlqS15bjuZXSG6h7a_UEw-bjXCZKJ5kwYxysk-wSSpVoJz21hmi_CFwWUUoNdHW8NCtCdr4cB2RUF64EaRN89hkP96xdpmEMS4vTEM0aDCOsuLFK1-dBZh5kaGz-tk_YqM6JuXQwOq3psz3uLcMTEG5JrwYCaGDYUOHQkFNhIEizq4BAbvFO3kXNITpRnsNynlmE3REOj45mj5MpJ7_d2rh78OLu0YgvWby4X_AYydCTK4mKp9E08sI-AMy28Wg=",
7674
"eJzdlEtPwzAMgP_KlMuGhKp1RwRIcOOAxIHbNEVuk21haVKlzrRq2n8nDuto9wSxE5eqduLPdvxYM2ctsrvemmFdyvDDnJyy2x5byJok4Yui5iAET9kmqG32IXOsvix8qWXQt6y_MWW9BRVWeB1VmVcalalIa6CIusiIZIWyIO54zF6MkKuBou_D8IauA5vc9roHaTzI2GRCTiSCAIRgb7zWxBMqxz8GR4hubHu-rpr3sTx2iYz-QSJkLiALPYMOltJV0vHm3i8qNVVCaJnwyuVJbs1UzdrxPDdO3hsfr00oe132hOgGZPbQnxpuHc911aemOusdrcvnK55BvpBGJCgr5GUQYKZMJ5Dd5LBN7N2WO3AzX0iDnMR9n935a2bsNANhdh6xHYTThLmqQlb1ZcgoQE41znUrFfu-tXp-2htGFpY7b464ewOHCvSZLoC9F9ASIn4J2pMiDf8l4DxasnvanI9J2EwHL5tdAI2OgTICnXzdbjUuTNLmCD_QSR04ufXmYIOn7Y25E0Zb4eLkpgf1SskbXVXWUMjDZJiEyD4BcFMKTQ==",
7775
]
7876

test/test_api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
from dataclasses import dataclass
1717
from unittest.mock import Mock
1818

19-
import nemo_run as run
2019
import pytest
20+
21+
import nemo_run as run
2122
from nemo_run.api import dryrun_fn
2223

2324

@@ -117,7 +118,10 @@ def test_dryrun_fn_with_executor(self, capsys, configured_fn):
117118

118119
captured = capsys.readouterr()
119120
assert "Dry run for task test.test_api:some_fn" in captured.out
120-
assert "LocalExecutor(packager=Packager(debug=False)" in captured.out
121+
assert (
122+
"LocalExecutor(packager=Packager(debug=False, symlink_from_remote_dir=None)"
123+
in captured.out
124+
)
121125

122126
def test_dryrun_fn_with_build(self, mocker, configured_fn):
123127
build_mock = Mock()

0 commit comments

Comments
 (0)