|
42 | 42 | from nemo_run.core.packaging.base import Packager |
43 | 43 | from nemo_run.core.packaging.git import GitArchivePackager |
44 | 44 | from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer |
45 | | -from nemo_run.core.tunnel.callback import Callback |
46 | | -from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel |
| 45 | +from nemo_run.core.tunnel.client import ( |
| 46 | + Callback, |
| 47 | + LocalTunnel, |
| 48 | + PackagingJob, |
| 49 | + SSHConfigFile, |
| 50 | + SSHTunnel, |
| 51 | + Tunnel, |
| 52 | +) |
47 | 53 | from nemo_run.core.tunnel.server import TunnelMetadata, server_dir |
48 | 54 | from nemo_run.devspace.base import DevSpace |
49 | 55 |
|
@@ -388,7 +394,7 @@ def __post_init__(self): |
388 | 394 | self.wait_time_for_group_job = 0 |
389 | 395 |
|
390 | 396 | def info(self) -> str: |
391 | | - return f"{self.__class__.__qualname__} on {self.tunnel._key}" |
| 397 | + return f"{self.__class__.__qualname__} on {self.tunnel.key}" |
392 | 398 |
|
393 | 399 | def alloc(self, job_name="interactive"): |
394 | 400 | self.job_name = f"{self.job_name_prefix}{job_name}" |
@@ -537,13 +543,39 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]: |
537 | 543 | return filenames |
538 | 544 |
|
539 | 545 | def package(self, packager: Packager, job_name: str): |
540 | | - if job_name in self.tunnel.packaging_jobs: |
| 546 | + if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir: |
541 | 547 | logger.info( |
542 | 548 | f"Packaging for job {job_name} in tunnel {self.tunnel} already done. Skipping subsequent packagings.\n" |
543 | 549 | "This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used." |
544 | 550 | ) |
545 | 551 | return |
546 | 552 |
|
| 553 | + if packager.symlink_from_remote_dir: |
| 554 | + logger.info( |
| 555 | + f"Packager {packager} is configured to symlink from remote dir. Skipping packaging." |
| 556 | + ) |
| 557 | + if type(packager) is Packager: |
| 558 | + self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False) |
| 559 | + return |
| 560 | + |
| 561 | + self.tunnel.packaging_jobs[job_name] = PackagingJob( |
| 562 | + symlink=True, |
| 563 | + src_path=packager.symlink_from_remote_dir, |
| 564 | + dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"), |
| 565 | + ) |
| 566 | + |
| 567 | + # Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up |
| 568 | + base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent) |
| 569 | + base_remote_mount = f"{base_remote_dir}:{base_remote_dir}" |
| 570 | + if base_remote_mount not in self.container_mounts: |
| 571 | + self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}") |
| 572 | + |
| 573 | + for req in self.resource_group: |
| 574 | + if base_remote_mount not in req.container_mounts: |
| 575 | + req.container_mounts.append(base_remote_mount) |
| 576 | + |
| 577 | + return |
| 578 | + |
547 | 579 | assert self.experiment_id, "Executor not assigned to an experiment." |
548 | 580 | if isinstance(packager, GitArchivePackager): |
549 | 581 | output = subprocess.run( |
@@ -573,7 +605,12 @@ def package(self, packager: Packager, job_name: str): |
573 | 605 | f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True |
574 | 606 | ) |
575 | 607 |
|
576 | | - self.tunnel.packaging_jobs.add(job_name) |
| 608 | + self.tunnel.packaging_jobs[job_name] = PackagingJob( |
| 609 | + symlink=False, |
| 610 | + dst_path=None |
| 611 | + if type(packager) is Packager |
| 612 | + else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"), |
| 613 | + ) |
577 | 614 |
|
578 | 615 | def parse_deps(self) -> list[str]: |
579 | 616 | """ |
|
0 commit comments