Skip to content

Commit 6236229

Browse files
committed
Sync job code in local tunnel for Slurm Ray job
Signed-off-by: Hemil Desai <[email protected]>
1 parent 75bc3f5 commit 6236229

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

nemo_run/run/ray/slurm.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,13 +1106,14 @@ def start(
11061106
# ------------------------------------------------------------------
11071107
# Ship *workdir* over to the remote side (or package via packager)
11081108
# ------------------------------------------------------------------
1109+
cluster_dir = os.path.join(self.executor.tunnel.job_dir, self.name)
11091110
remote_workdir: Optional[str] = None
11101111

11111112
if workdir:
1112-
if isinstance(self.executor.tunnel, SSHTunnel):
1113-
# Rsync workdir honouring .gitignore
1114-
remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code")
1115-
if not dryrun:
1113+
remote_workdir = os.path.join(cluster_dir, "code")
1114+
if not dryrun:
1115+
if isinstance(self.executor.tunnel, SSHTunnel):
1116+
# Rsync workdir honouring .gitignore
11161117
self.executor.tunnel.connect()
11171118
assert self.executor.tunnel.session is not None, (
11181119
"Tunnel session is not connected"
@@ -1123,11 +1124,22 @@ def start(
11231124
remote_workdir,
11241125
rsync_opts="--filter=':- .gitignore'",
11251126
)
1126-
else:
1127-
remote_workdir = workdir
1127+
else:
1128+
os.makedirs(remote_workdir, exist_ok=True)
1129+
subprocess.run(
1130+
[
1131+
"rsync",
1132+
"-pthrvz",
1133+
"--filter=:- .gitignore",
1134+
f"{os.path.join(workdir, '')}",
1135+
remote_workdir,
1136+
],
1137+
check=True,
1138+
)
11281139
elif self.executor.packager is not None:
11291140
# Use the packager to create an archive which we then extract on the
11301141
# submission host and optionally rsync to the target.
1142+
remote_workdir = os.path.join(cluster_dir, "code")
11311143
if not dryrun:
11321144
if isinstance(self.executor.tunnel, SSHTunnel):
11331145
package_dir = tempfile.mkdtemp(prefix="nemo_packager_")
@@ -1157,7 +1169,6 @@ def start(
11571169
)
11581170

11591171
if isinstance(self.executor.tunnel, SSHTunnel):
1160-
remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code")
11611172
self.executor.tunnel.connect()
11621173
assert self.executor.tunnel.session is not None, (
11631174
"Tunnel session is not connected"
@@ -1169,7 +1180,17 @@ def start(
11691180
rsync_opts="--filter=':- .gitignore'",
11701181
)
11711182
else:
1172-
remote_workdir = local_code_extraction_path
1183+
os.makedirs(remote_workdir, exist_ok=True)
1184+
subprocess.run(
1185+
[
1186+
"rsync",
1187+
"-pthrvz",
1188+
"--filter=:- .gitignore",
1189+
f"{os.path.join(local_code_extraction_path, '')}",
1190+
remote_workdir,
1191+
],
1192+
check=True,
1193+
)
11731194

11741195
assert remote_workdir is not None, "workdir could not be determined"
11751196

0 commit comments

Comments
 (0)