Skip to content

Commit 0b0c81a

Browse files
committed
Add support for docker
Signed-off-by: Hemil Desai <[email protected]>
1 parent f790429 commit 0b0c81a

File tree

3 files changed: 10 additions & 4 deletions

3 files changed

+10
-4
lines changed

src/nemo_run/config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
RECURSIVE_TYPES = (typing.Union, typing.Optional)
4747
NEMORUN_HOME = os.environ.get("NEMORUN_HOME", os.path.expanduser("~/.nemo_run"))
4848
RUNDIR_NAME = "nemo_run"
49+
RUNDIR_SPECIAL_NAME = "/$nemo_run"
4950
SCRIPTS_DIR = "scripts"
5051

5152

src/nemo_run/core/execution/slurm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
from rich.console import Console
3030
from rich.text import Text
3131

32-
from nemo_run.config import RUNDIR_NAME
32+
from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME
3333
from nemo_run.core.execution.base import (
3434
Executor,
3535
ExecutorMacros,
@@ -866,8 +866,8 @@ def get_container_flags(
866866

867867
new_mounts = copy.deepcopy(base_mounts)
868868
for i, mount in enumerate(new_mounts):
869-
if mount.startswith("/$nemo_run"):
870-
new_mounts[i] = mount.replace("$nemo_run", src_job_dir)
869+
if mount.startswith(RUNDIR_SPECIAL_NAME):
870+
new_mounts[i] = mount.replace(RUNDIR_SPECIAL_NAME, src_job_dir, 1)
871871

872872
new_mounts.append(f"{src_job_dir}:/{RUNDIR_NAME}")
873873
_mount_arg = ",".join(new_mounts)

src/nemo_run/run/torchx_backend/schedulers/docker.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
is_terminal,
4343
)
4444

45-
from nemo_run.config import RUNDIR_NAME
45+
from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME
4646
from nemo_run.core.execution.base import Executor
4747
from nemo_run.core.execution.docker import (
4848
DockerContainer,
@@ -120,6 +120,11 @@ def schedule(self, dryrun_info: AppDryRunInfo[DockerJobRequest]) -> str: # type
120120
log.warning(f"failed to pull image {image}, falling back to local: {e}")
121121

122122
for container in req.containers:
123+
for i, mount in enumerate(container.executor.volumes):
124+
if mount.startswith(RUNDIR_SPECIAL_NAME):
125+
container.executor.volumes[i] = mount.replace(
126+
RUNDIR_SPECIAL_NAME, req.executor.job_dir, 1
127+
)
123128
container.executor.volumes.append(f"{req.executor.job_dir}:/{RUNDIR_NAME}")
124129

125130
req.run(client=client)

0 commit comments

Comments
 (0)