Skip to content

Commit 213c399

Browse files
Correctly append tar files for packaging (#317)
* Correctly append tar files for packaging Signed-off-by: Sahil Modi <[email protected]> * tests Signed-off-by: Hemil Desai <[email protected]> --------- Signed-off-by: Sahil Modi <[email protected]> Signed-off-by: Hemil Desai <[email protected]> Co-authored-by: Hemil Desai <[email protected]>
1 parent 46d1dce commit 213c399

File tree

2 files changed

+148
-30
lines changed

2 files changed

+148
-30
lines changed

nemo_run/core/packaging/git.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import uuid
2121
from dataclasses import dataclass
2222
from pathlib import Path
23+
import re
2324

2425
from invoke.context import Context
2526

@@ -72,6 +73,38 @@ class GitArchivePackager(Packager):
7273
check_uncommitted_changes: bool = False
7374
check_untracked_files: bool = False
7475

76+
def _concatenate_tar_files(
77+
self, ctx: Context, output_file: str, files_to_concatenate: list[str]
78+
):
79+
"""Concatenate multiple uncompressed tar files into a single tar archive.
80+
81+
The list should include ALL fragments to merge (base + additions).
82+
Creates/overwrites `output_file`.
83+
"""
84+
if not files_to_concatenate:
85+
raise ValueError("files_to_concatenate must not be empty")
86+
87+
# Quote paths for shell safety
88+
quoted_files = [shlex.quote(f) for f in files_to_concatenate]
89+
quoted_output_file = shlex.quote(output_file)
90+
91+
if os.uname().sysname == "Linux":
92+
# Start from the first archive then append the rest, to avoid self-append issues
93+
first_file, *rest_files = quoted_files
94+
ctx.run(f"cp {first_file} {quoted_output_file}")
95+
if rest_files:
96+
ctx.run(f"tar Af {quoted_output_file} {' '.join(rest_files)}")
97+
# Remove all input fragments
98+
ctx.run(f"rm {' '.join(quoted_files)}")
99+
else:
100+
# Extract all fragments and repack once (faster than iterative extract/append)
101+
temp_dir = f"temp_extract_{uuid.uuid4()}"
102+
ctx.run(f"mkdir -p {temp_dir}")
103+
for file in quoted_files:
104+
ctx.run(f"tar xf {file} -C {temp_dir}")
105+
ctx.run(f"tar cf {quoted_output_file} -C {temp_dir} .")
106+
ctx.run(f"rm -r {temp_dir} {' '.join(quoted_files)}")
107+
75108
def package(self, path: Path, job_dir: str, name: str) -> str:
76109
output_file = os.path.join(job_dir, f"{name}.tar.gz")
77110
if os.path.exists(output_file):
@@ -113,20 +146,11 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
113146
)
114147

115148
ctx = Context()
116-
# we first add git files into an uncompressed archive
117-
# then we add submodule files into that archive
118-
# then we add an extra files from pattern to that archive
119-
# finally we compress it (cannot compress right away, since adding files is not possible)
120-
git_archive_cmd = (
121-
f"git archive --format=tar --output={output_file}.tmp {self.ref}:{git_sub_path}"
122-
)
123-
if os.uname().sysname == "Linux":
124-
tar_submodule_cmd = f"tar Af {output_file}.tmp $sha1.tmp && rm $sha1.tmp"
125-
else:
126-
tar_submodule_cmd = f"cat $sha1.tmp >> {output_file}.tmp && rm $sha1.tmp"
127-
128-
git_submodule_cmd = f"""git submodule foreach --recursive \
129-
'git archive --format=tar --prefix=$sm_path/ --output=$sha1.tmp HEAD && {tar_submodule_cmd}'"""
149+
# Build the base uncompressed archive, then separately generate all additional fragments.
150+
# Finally, concatenate all fragments in one pass for performance and portability.
151+
base_tar_tmp = f"{output_file}.tmp.base"
152+
git_archive_cmd = f"git archive --format=tar --output={shlex.quote(base_tar_tmp)} {self.ref}:{git_sub_path}"
153+
git_submodule_cmd = "git submodule foreach --recursive 'git archive --format=tar --prefix=$sm_path/ --output=$sha1.tmp HEAD'"
130154

131155
with ctx.cd(git_base_path):
132156
ctx.run(git_archive_cmd)
@@ -143,6 +167,16 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
143167
"include_pattern and include_pattern_relative_path should have the same length"
144168
)
145169

170+
# Collect submodule tar fragments (named as <40-hex-sha1>.tmp) if any
171+
submodule_tmp_files: list[str] = []
172+
if self.include_submodules:
173+
for dirpath, _dirnames, filenames in os.walk(git_base_path):
174+
for filename in filenames:
175+
if re.fullmatch(r"[0-9a-f]{40}\.tmp", filename):
176+
submodule_tmp_files.append(os.path.join(dirpath, filename))
177+
178+
# Generate additional fragments from include patterns and collect their paths
179+
additional_tmp_files: list[str] = []
146180
for include_pattern, include_pattern_relative_path in zip(
147181
self.include_pattern, self.include_pattern_relative_path
148182
):
@@ -158,26 +192,16 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
158192
include_pattern, include_pattern_relative_path
159193
)
160194
pattern_tar_file_name = os.path.join(git_base_path, pattern_tar_file_name)
161-
include_pattern_cmd = (
162-
f"find {relative_include_pattern} -type f | tar -cf {pattern_tar_file_name} -T -"
163-
)
195+
include_pattern_cmd = f"find {relative_include_pattern} -type f | tar -cf {shlex.quote(pattern_tar_file_name)} -T -"
164196

165197
with ctx.cd(include_pattern_relative_path):
166198
ctx.run(include_pattern_cmd)
199+
additional_tmp_files.append(pattern_tar_file_name)
167200

168-
with ctx.cd(git_base_path):
169-
if os.uname().sysname == "Linux":
170-
# On Linux, directly concatenate tar files
171-
ctx.run(f"tar Af {output_file}.tmp {pattern_tar_file_name}")
172-
ctx.run(f"rm {pattern_tar_file_name}")
173-
else:
174-
# Extract and repack approach for other platforms
175-
temp_dir = f"temp_extract_{pattern_file_id}"
176-
ctx.run(f"mkdir -p {temp_dir}")
177-
ctx.run(f"tar xf {output_file}.tmp -C {temp_dir}")
178-
ctx.run(f"tar xf {pattern_tar_file_name} -C {temp_dir}")
179-
ctx.run(f"tar cf {output_file}.tmp -C {temp_dir} .")
180-
ctx.run(f"rm -rf {temp_dir} {pattern_tar_file_name}")
201+
# Concatenate all fragments in one pass into {output_file}.tmp
202+
fragments_to_merge: list[str] = [base_tar_tmp] + submodule_tmp_files + additional_tmp_files
203+
with ctx.cd(git_base_path):
204+
self._concatenate_tar_files(ctx, f"{output_file}.tmp", fragments_to_merge)
181205

182206
gzip_cmd = f"gzip -c {output_file}.tmp > {output_file}"
183207
rm_cmd = f"rm {output_file}.tmp"

test/core/packaging/test_git.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import os
1818
import shlex
1919
import subprocess
20+
import tarfile
2021
import tempfile
2122
from pathlib import Path
23+
from types import SimpleNamespace
2224
from unittest.mock import patch
2325

2426
import invoke
@@ -418,3 +420,95 @@ def test_package_without_include_submodules(packager, temp_repo):
418420
),
419421
)
420422
assert len(os.listdir(os.path.join(job_dir, "extracted_output", "submodule"))) == 0
423+
424+
425+
def _make_uncompressed_tar_from_dir(src_dir: Path, tar_path: Path):
426+
# Create an uncompressed tar at tar_path from the contents of src_dir
427+
# with files at the root of the archive
428+
with tarfile.open(tar_path, mode="w") as tf:
429+
for entry in sorted(src_dir.iterdir()):
430+
tf.add(entry, arcname=entry.name)
431+
432+
433+
@patch("nemo_run.core.packaging.git.Context", MockContext)
434+
def test_concatenate_tar_files_non_linux_integration(tmp_path, monkeypatch):
435+
# Force non-Linux path (extract+repack)
436+
monkeypatch.setattr(os, "uname", lambda: SimpleNamespace(sysname="Darwin"))
437+
438+
# Prepare two small tar fragments
439+
dir_a = tmp_path / "a"
440+
dir_b = tmp_path / "b"
441+
dir_a.mkdir()
442+
dir_b.mkdir()
443+
(dir_a / "fileA.txt").write_text("A")
444+
(dir_b / "fileB.txt").write_text("B")
445+
446+
tar_a = tmp_path / "a.tar"
447+
tar_b = tmp_path / "b.tar"
448+
_make_uncompressed_tar_from_dir(dir_a, tar_a)
449+
_make_uncompressed_tar_from_dir(dir_b, tar_b)
450+
451+
out_tar = tmp_path / "out.tar"
452+
packager = GitArchivePackager()
453+
ctx = MockContext()
454+
packager._concatenate_tar_files(ctx, str(out_tar), [str(tar_a), str(tar_b)])
455+
456+
# Inputs removed
457+
assert not tar_a.exists() and not tar_b.exists()
458+
459+
# Output contains both files at root
460+
assert out_tar.exists()
461+
with tarfile.open(out_tar, mode="r") as tf:
462+
names = sorted(m.name for m in tf.getmembers() if m.isfile())
463+
assert names == ["./fileA.txt", "./fileB.txt"]
464+
465+
466+
def test_concatenate_tar_files_linux_emits_expected_commands(monkeypatch, tmp_path):
467+
# Simulate Linux branch; use a dummy Context that records commands instead of executing
468+
monkeypatch.setattr(os, "uname", lambda: SimpleNamespace(sysname="Linux"))
469+
470+
class DummyContext:
471+
def __init__(self):
472+
self.commands: list[str] = []
473+
474+
def run(self, cmd: str, **_kwargs):
475+
self.commands.append(cmd)
476+
477+
# Support ctx.cd(...) context manager API
478+
def cd(self, _path: Path):
479+
class _CD:
480+
def __enter__(self_nonlocal):
481+
return self
482+
483+
def __exit__(self_nonlocal, exc_type, exc, tb):
484+
return False
485+
486+
return _CD()
487+
488+
# Fake inputs (do not need to exist since we don't execute)
489+
tar1 = str(tmp_path / "one.tar")
490+
tar2 = str(tmp_path / "two.tar")
491+
tar3 = str(tmp_path / "three.tar")
492+
out_tar = str(tmp_path / "out.tar")
493+
494+
ctx = DummyContext()
495+
packager = GitArchivePackager()
496+
packager._concatenate_tar_files(ctx, out_tar, [tar1, tar2, tar3])
497+
498+
# Expected sequence: cp first -> tar Af rest -> rm all inputs
499+
assert len(ctx.commands) == 3
500+
assert ctx.commands[0] == f"cp {shlex.quote(tar1)} {shlex.quote(out_tar)}"
501+
assert (
502+
ctx.commands[1] == f"tar Af {shlex.quote(out_tar)} {shlex.quote(tar2)} {shlex.quote(tar3)}"
503+
)
504+
assert ctx.commands[2] == f"rm {shlex.quote(tar1)} {shlex.quote(tar2)} {shlex.quote(tar3)}"
505+
506+
507+
@patch("nemo_run.core.packaging.git.Context", MockContext)
508+
def test_include_pattern_length_mismatch_raises(packager, temp_repo):
509+
# Mismatch between include_pattern and include_pattern_relative_path should raise
510+
packager.include_pattern = ["extra"]
511+
packager.include_pattern_relative_path = ["/tmp", "/also/tmp"]
512+
with tempfile.TemporaryDirectory() as job_dir:
513+
with pytest.raises(ValueError, match="same length"):
514+
packager.package(Path(temp_repo), job_dir, "mismatch")

0 commit comments

Comments
 (0)