Skip to content

Commit 27bccfa

Browse files
authored
Fix git packager for git repo with submodules (#109)
1 parent 070fe56 commit 27bccfa

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

src/nemo_run/core/packaging/git.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class GitArchivePackager(Packager):
5757
#: Can be a branch name or a commit ref like HEAD.
5858
ref: str = "HEAD"
5959

60+
#: Include submodules in the archive.
61+
include_submodules: bool = True
62+
6063
#: Include extra files in the archive which matches include_pattern
6164
#: This str will be included in the command as: find {include_pattern} -type f to get the list of extra files to include in the archive
6265
include_pattern: str = ""
@@ -109,40 +112,42 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
109112
), "Your repo has untracked files. Please track your files via git or set check_untracked_files to False to proceed with packaging."
110113

111114
ctx = Context()
115+
# we first add git files into an uncompressed archive
116+
# then we add submodule files into that archive
117+
# then we add an extra files from pattern to that archive
118+
# finally we compress it (cannot compress right away, since adding files is not possible)
119+
git_archive_cmd = (
120+
f"git archive --format=tar --output={output_file}.tmp {self.ref}:{git_sub_path}"
121+
)
122+
git_submodule_cmd = f"""git submodule foreach --recursive \
123+
'git archive --format=tar --prefix=$sm_path/ --output=$sha1.tmp HEAD && tar -Af {output_file}.tmp $sha1.tmp && rm $sha1.tmp'"""
124+
with ctx.cd(git_base_path):
125+
ctx.run(git_archive_cmd)
126+
if self.include_submodules:
127+
ctx.run(git_submodule_cmd)
128+
112129
if self.include_pattern:
113130
include_pattern_relative_path = self.include_pattern_relative_path or shlex.quote(
114131
str(git_base_path)
115132
)
116133
relative_include_pattern = os.path.relpath(
117134
self.include_pattern, include_pattern_relative_path
118135
)
119-
# we first add git files into an uncompressed archive
120-
# then we add an extra files from pattern to that archive
121-
# finally we compress it (cannot compress right away, since adding files is not possible)
122-
git_archive_cmd = (
123-
f"git archive --format=tar --output={output_file}.tmp {self.ref}:{git_sub_path}"
124-
)
125136
include_pattern_cmd = f"find {relative_include_pattern} -type f | tar -cf {os.path.join(git_base_path, 'additional.tmp')} -T -"
126-
tar_concatenate_cmd = f"tar -Af {output_file}.tmp additional.tmp"
127-
gzip_cmd = f"gzip -c {output_file}.tmp > {output_file}"
128-
rm_cmd = f"rm {output_file}.tmp additional.tmp"
129-
130-
with ctx.cd(git_base_path):
131-
ctx.run(git_archive_cmd)
137+
tar_concatenate_cmd = f"tar -Af {output_file}.tmp additional.tmp && rm additional.tmp"
132138

133139
with ctx.cd(include_pattern_relative_path):
134140
ctx.run(include_pattern_cmd)
135141

136142
with ctx.cd(git_base_path):
137143
ctx.run(tar_concatenate_cmd)
138-
ctx.run(gzip_cmd)
139-
ctx.run(rm_cmd)
140-
else:
141-
with ctx.cd(git_base_path):
142-
git_archive_cmd = (
143-
f"git archive --format=tar.gz --output={output_file} {self.ref}:{git_sub_path}"
144-
)
145-
ctx.run(git_archive_cmd)
144+
145+
gzip_cmd = f"gzip -c {output_file}.tmp > {output_file}"
146+
rm_cmd = f"rm {output_file}.tmp"
147+
148+
with ctx.cd(git_base_path):
149+
ctx.run(gzip_cmd)
150+
ctx.run(rm_cmd)
146151

147152
return output_file
148153

test/core/packaging/test_git.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,60 @@ def test_untracked_files_raises_exception(temp_repo):
290290
f.write("Untracked file")
291291
with pytest.raises(AssertionError, match="Your repo has untracked files"):
292292
packager.package(temp_repo, str(temp_repo), "test")
293+
294+
295+
@patch("nemo_run.core.packaging.git.Context", MockContext)
296+
def test_package_with_include_submodules(packager, temp_repo):
297+
temp_repo = Path(temp_repo)
298+
# Create a submodule
299+
submodule_path = temp_repo / "submodule"
300+
submodule_path.mkdir()
301+
os.chdir(str(submodule_path))
302+
subprocess.check_call(["git", "init", "--initial-branch=main"])
303+
open("submodule_file.txt", "w").write("Submodule file")
304+
subprocess.check_call(["git", "add", "."])
305+
subprocess.check_call(["git", "commit", "-m", "Initial submodule commit"])
306+
os.chdir(str(temp_repo))
307+
subprocess.check_call(["git", "submodule", "add", str(submodule_path)])
308+
subprocess.check_call(["git", "commit", "-m", "Add submodule"])
309+
310+
packager = GitArchivePackager(ref="HEAD", include_submodules=True)
311+
with tempfile.TemporaryDirectory() as job_dir:
312+
output_file = packager.package(Path(temp_repo), job_dir, "test_package")
313+
assert os.path.exists(output_file)
314+
subprocess.check_call(shlex.split(f"mkdir -p {os.path.join(job_dir, 'extracted_output')}"))
315+
subprocess.check_call(
316+
shlex.split(f"tar -xvzf {output_file} -C {os.path.join(job_dir, 'extracted_output')}"),
317+
)
318+
cmp = filecmp.dircmp(
319+
os.path.join(temp_repo, "submodule"),
320+
os.path.join(job_dir, "extracted_output", "submodule"),
321+
)
322+
assert cmp.left_list == cmp.right_list
323+
assert not cmp.diff_files
324+
325+
326+
@patch("nemo_run.core.packaging.git.Context", MockContext)
327+
def test_package_without_include_submodules(packager, temp_repo):
328+
temp_repo = Path(temp_repo)
329+
# Create a submodule
330+
submodule_path = temp_repo / "submodule"
331+
submodule_path.mkdir()
332+
os.chdir(str(submodule_path))
333+
subprocess.check_call(["git", "init", "--initial-branch=main"])
334+
open("submodule_file.txt", "w").write("Submodule file")
335+
subprocess.check_call(["git", "add", "."])
336+
subprocess.check_call(["git", "commit", "-m", "Initial submodule commit"])
337+
os.chdir(str(temp_repo))
338+
subprocess.check_call(["git", "submodule", "add", str(submodule_path)])
339+
subprocess.check_call(["git", "commit", "-m", "Add submodule"])
340+
341+
packager = GitArchivePackager(ref="HEAD", include_submodules=False)
342+
with tempfile.TemporaryDirectory() as job_dir:
343+
output_file = packager.package(Path(temp_repo), job_dir, "test_package")
344+
assert os.path.exists(output_file)
345+
subprocess.check_call(shlex.split(f"mkdir -p {os.path.join(job_dir, 'extracted_output')}"))
346+
subprocess.check_call(
347+
shlex.split(f"tar -xvzf {output_file} -C {os.path.join(job_dir, 'extracted_output')}"),
348+
)
349+
assert len(os.listdir(os.path.join(job_dir, "extracted_output", "submodule"))) == 0

0 commit comments

Comments
 (0)