Skip to content

Commit d156258

Browse files
authored
Merge pull request #424 from dhellmann/patch-dir-lookup
standardize patch directories using override module names
2 parents 5a58514 + ccc3f01 commit d156258

File tree

5 files changed

+121
-173
lines changed

5 files changed

+121
-173
lines changed

e2e/test_bootstrap_extras.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ $OUTDIR/work-dir/setuptools-*/build.log
4141
4242
$OUTDIR/wheels-repo/downloads/PySocks-*.whl
4343
$OUTDIR/sdists-repo/downloads/PySocks-*.tar.gz
44-
$OUTDIR/sdists-repo/builds/PySocks-*.tar.gz
45-
$OUTDIR/work-dir/PySocks-*/build.log
44+
$OUTDIR/sdists-repo/builds/pysocks-*.tar.gz
45+
$OUTDIR/work-dir/pysocks-*/build.log
4646
"
4747

4848
for pattern in $EXPECTED_FILES; do

src/fromager/overrides.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import inspect
2+
import itertools
23
import logging
34
import pathlib
4-
import re
55
import typing
66
from importlib import metadata
77

8+
from packaging.requirements import Requirement
89
from packaging.utils import canonicalize_name
10+
from packaging.version import Version
911
from stevedore import extension
1012

1113
# An interface for retrieving per-package information which influences
@@ -88,45 +90,39 @@ def log_overrides() -> None:
8890
)
8991

9092

91-
def patches_for_source_dir(
92-
patches_dir: pathlib.Path, source_dir_name: str
93+
def patches_for_requirement(
94+
patches_dir: pathlib.Path,
95+
req: Requirement,
96+
version: Version,
9397
) -> typing.Iterable[pathlib.Path]:
94-
"""Iterator producing patches to apply to the source dir.
98+
"""Iterator producing patches to apply to the source for a given version of a requirement.
9599
96-
Input should be the base directory name, not a full path.
97-
98-
Yields pathlib.Path() references to patches in the order they
99-
should be applied, which is controlled through lexical sorting of
100-
the filenames.
100+
Yields pathlib.Path() references to patches in the order they should be
101+
applied, which is controlled through lexical sorting of the filenames.
101102
102103
"""
103-
return sorted((patches_dir / source_dir_name).glob("*.patch"))
104-
105-
106-
def get_patch_directories(
107-
patches_dir: pathlib.Path, source_root_dir: pathlib.Path
108-
) -> list[pathlib.Path]:
104+
override_name = pkgname_to_override_module(req.name)
105+
unversioned_patch_dir = patches_dir / override_name
106+
versioned_patch_dir = patches_dir / f"{override_name}-{version}"
107+
return itertools.chain(
108+
# Apply all of the unversioned patches first, in order based on
109+
# filename.
110+
sorted(unversioned_patch_dir.glob("*.patch")),
111+
# Then apply any for this specific version, in order based on filename.
112+
sorted(versioned_patch_dir.glob("*.patch")),
113+
)
114+
115+
116+
def get_versioned_patch_directories(
117+
patches_dir: pathlib.Path,
118+
req: Requirement,
119+
) -> typing.Generator[pathlib.Path, None, None]:
109120
"""
110-
This function will return directories that may contain patches for a specific requirement.
111-
It takes in patches directory and a source root directory as input.
112-
The output will be a list of all directories containing patches for that requirement
121+
This function will return directories that may contain patches for any version of a specific requirement.
113122
"""
114123
# Get the override module name used to name this requirement's patch directories
115-
req_name = source_root_dir.name.rsplit("-", 1)[0]
116-
patches = sorted((patches_dir).glob(f"{req_name}*"))
117-
filtered_patches = _filter_patches_based_on_req(patches, req_name)
118-
return filtered_patches
119-
120-
121-
# Helper method to filter the unwanted patches using a regex
122-
def _filter_patches_based_on_req(
123-
patches: list[pathlib.Path], req_name: str
124-
) -> list[pathlib.Path]:
125-
# Set up regex to filter out unwanted patches.
126-
pattern = re.compile(rf"^{req_name}-v?(\d+\.)+\d+")
127-
filtered_patches = [s for s in patches if pattern.match(s.name)]
128-
# filtered_patches won't contain patches for current version of req
129-
return filtered_patches
124+
override_name = pkgname_to_override_module(req.name)
125+
return patches_dir.glob(f"{override_name}-*")
130126

131127

132128
def pkgname_to_override_module(pkgname: str) -> str:

src/fromager/sources.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -213,21 +213,28 @@ def _takes_arg(f: typing.Callable, arg_name: str) -> bool:
213213

214214
def unpack_source(
215215
ctx: context.WorkContext,
216+
req: Requirement,
217+
version: Version,
216218
source_filename: pathlib.Path,
217219
) -> tuple[pathlib.Path, bool]:
218-
sdist_root_name = _sdist_root_name(source_filename)
219-
unpack_dir = ctx.work_dir / sdist_root_name
220+
# sdist names are less standardized and the names of the directories they
221+
# contain are also not very standard. Force the names into a predictable
222+
# form based on the override module name for the requirement.
223+
req_name = overrides.pkgname_to_override_module(req.name)
224+
expected_name = f"{req_name}-{version}"
225+
226+
# The unpack_dir is a parent dir where we put temporary outputs during the
227+
# build process, including the unpacked source in a subdirectory.
228+
unpack_dir = ctx.work_dir / expected_name
220229
if unpack_dir.exists():
221230
if ctx.cleanup:
222231
logger.debug("cleaning up %s", unpack_dir)
223232
shutil.rmtree(unpack_dir)
224233
else:
225234
logger.info("reusing %s", unpack_dir)
226235
return (unpack_dir / unpack_dir.name, False)
227-
# We create a unique directory based on the sdist name, but that
228-
# may not be the same name as the root directory of the content in
229-
# the sdist (due to case, punctuation, etc.), so after we unpack
230-
# it look for what was created.
236+
237+
# sdists might be tarballs or zip files.
231238
logger.debug("unpacking %s to %s", source_filename, unpack_dir)
232239
if str(source_filename).endswith(".tar.gz"):
233240
with tarfile.open(source_filename, "r") as t:
@@ -242,45 +249,51 @@ def unpack_source(
242249
else:
243250
raise ValueError(f"Do not know how to unpack source archive {source_filename}")
244251

245-
# if tarball named foo-2.3.1.tar.gz was downloaded, then ensure that after unpacking, the source directory's path is foo-2.3.1/foo-2.3.1
252+
# We create a unique directory based on the requirement name, but that may
253+
# not be the same name as the root directory of the content in the sdist
254+
# (due to case, punctuation, etc.), so after we unpack it look for what was
255+
# created and ensure the extracted directory matches the override module
256+
# name and version of the requirement.
246257
unpacked_root_dir = next(iter(unpack_dir.glob("*")))
247-
new_unpacked_root_dir = unpacked_root_dir.parent / sdist_root_name
248-
if unpacked_root_dir.name != new_unpacked_root_dir.name:
258+
if unpacked_root_dir.name != expected_name:
259+
desired_name = unpacked_root_dir.parent / expected_name
249260
try:
250-
shutil.move(str(unpacked_root_dir), str(new_unpacked_root_dir))
261+
shutil.move(
262+
str(unpacked_root_dir),
263+
str(desired_name),
264+
)
251265
except Exception as err:
252266
raise Exception(
253-
f"Could not rename {unpacked_root_dir.name} to {new_unpacked_root_dir.name}: {err}"
267+
f"Could not rename {unpacked_root_dir.name} to {desired_name}: {err}"
254268
) from err
269+
unpacked_root_dir = desired_name
255270

256-
return (new_unpacked_root_dir, True)
271+
return (unpacked_root_dir, True)
257272

258273

259274
def patch_source(
260275
ctx: context.WorkContext,
261276
source_root_dir: pathlib.Path,
262277
req: Requirement,
278+
version: Version,
263279
) -> None:
264-
# Flag to check whether patch has been applied
265-
has_applied = False
266-
# apply any unversioned patch first
267-
for p in overrides.patches_for_source_dir(
268-
ctx.settings.patches_dir, overrides.pkgname_to_override_module(req.name)
280+
patch_count = 0
281+
for p in overrides.patches_for_requirement(
282+
patches_dir=ctx.settings.patches_dir,
283+
req=req,
284+
version=version,
269285
):
270286
_apply_patch(p, source_root_dir)
271-
has_applied = True
272-
273-
# make sure that we don't apply the generic unversioned patch again
274-
if source_root_dir.name != overrides.pkgname_to_override_module(req.name):
275-
for p in overrides.patches_for_source_dir(
276-
ctx.settings.patches_dir, source_root_dir.name
277-
):
278-
_apply_patch(p, source_root_dir)
279-
has_applied = True
287+
patch_count += 1
280288

289+
logger.debug("%s: applied %d patches", req.name, patch_count)
281290
# If no patch has been applied, warn about existing versioned patches that were not used
282-
if not has_applied:
283-
_warn_for_old_patch(source_root_dir, ctx.settings.patches_dir)
291+
if not patch_count:
292+
_warn_for_old_patch(
293+
req=req,
294+
version=version,
295+
patches_dir=ctx.settings.patches_dir,
296+
)
284297

285298

286299
def _apply_patch(patch: pathlib.Path, source_root_dir: pathlib.Path):
@@ -290,18 +303,20 @@ def _apply_patch(patch: pathlib.Path, source_root_dir: pathlib.Path):
290303

291304

292305
def _warn_for_old_patch(
293-
source_root_dir: pathlib.Path,
306+
req: Requirement,
307+
version: Version,
294308
patches_dir: pathlib.Path,
295309
) -> None:
296-
# Get the req name as per the source_root_dir naming conventions
297-
req_name = source_root_dir.name.rsplit("-", 1)[0]
298-
299310
# Find all versioned patch directories for this requirement
300-
patch_directories = overrides.get_patch_directories(patches_dir, source_root_dir)
311+
patch_directories = overrides.get_versioned_patch_directories(
312+
patches_dir=patches_dir, req=req
313+
)
301314

302315
for dirs in patch_directories:
303316
for p in dirs.iterdir():
304-
logger.warning(f"{req_name}: patch {p} exists but will not be applied")
317+
logger.warning(
318+
f"{req.name}: patch {p} exists but will not be applied for version {version}"
319+
)
305320

306321

307322
def write_build_meta(
@@ -366,7 +381,12 @@ def default_prepare_source(
366381
source_filename: pathlib.Path,
367382
version: Version,
368383
) -> tuple[pathlib.Path, bool]:
369-
source_root_dir, is_new = unpack_source(ctx, source_filename)
384+
source_root_dir, is_new = unpack_source(
385+
ctx=ctx,
386+
req=req,
387+
version=version,
388+
source_filename=source_filename,
389+
)
370390
if is_new:
371391
prepare_new_source(
372392
ctx=ctx,
@@ -387,7 +407,7 @@ def prepare_new_source(
387407
388408
`default_prepare_source` runs this function when the sources are new.
389409
"""
390-
patch_source(ctx, source_root_dir, req)
410+
patch_source(ctx, source_root_dir, req, version)
391411
pyproject.apply_project_override(
392412
ctx=ctx,
393413
req=req,

tests/test_overrides.py

Lines changed: 11 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,39 @@
11
import pathlib
2-
from importlib import metadata
32
from unittest import mock
43
from unittest.mock import patch
54

65
import pytest
6+
from packaging.requirements import Requirement
7+
from packaging.version import Version
78

89
from fromager import overrides
910

1011

11-
def test_patches_for_source_dir(tmp_path: pathlib.Path):
12+
def test_patches_for_requirement(tmp_path: pathlib.Path):
1213
patches_dir = tmp_path / "patches"
1314
patches_dir.mkdir()
1415

1516
project_patch_dir = patches_dir / "project-1.2.3"
1617
project_patch_dir.mkdir()
1718

18-
project_variant_patch_dir = patches_dir / "project-1.2.3-variant"
19-
project_variant_patch_dir.mkdir()
20-
2119
p1 = project_patch_dir / "001.patch"
2220
p2 = project_patch_dir / "002.patch"
2321
np1 = project_patch_dir / "not-a-patch.txt"
24-
p3 = project_variant_patch_dir / "003.patch"
25-
np2 = project_variant_patch_dir / "not-a-patch.txt"
2622

2723
# Create all of the test files
28-
for p in [p1, p2, p3]:
24+
for p in [p1, p2]:
2925
p.write_text("this is a patch file")
30-
for f in [np1, np2]:
26+
for f in [np1]:
3127
f.write_text("this is not a patch file")
3228

33-
results = list(overrides.patches_for_source_dir(patches_dir, "project-1.2.3"))
34-
assert results == [p1, p2]
35-
3629
results = list(
37-
overrides.patches_for_source_dir(patches_dir, "project-1.2.3-variant")
30+
overrides.patches_for_requirement(
31+
patches_dir=patches_dir,
32+
req=Requirement("project"),
33+
version=Version("1.2.3"),
34+
)
3835
)
39-
assert results == [p3]
36+
assert results == [p1, p2]
4037

4138

4239
def test_invoke_override_with_exact_args():
@@ -73,86 +70,3 @@ def default_foo(arg1):
7370
assert overrides.find_and_invoke(
7471
"pkg", "foo", default_foo, arg1="value1", arg2="value2"
7572
)
76-
77-
78-
def test_regex_dummy_package(tmp_path: pathlib.Path):
79-
req_name = "foo"
80-
patches_dir = tmp_path / "patches_dir"
81-
patches_dir.mkdir()
82-
83-
lst = [
84-
patches_dir / "foo-1.1.0",
85-
patches_dir / "foo-bar-2.0.0",
86-
patches_dir / "foo-v2.3.0",
87-
patches_dir / "foo-bar-bar-v2.3.1",
88-
patches_dir / "foo-bar-v5.5.5",
89-
patches_dir / "foo-3.4.4",
90-
patches_dir / "foo-v2.3.0.1",
91-
]
92-
93-
expected = [
94-
patches_dir / "foo-1.1.0",
95-
patches_dir / "foo-v2.3.0",
96-
patches_dir / "foo-3.4.4",
97-
patches_dir / "foo-v2.3.0.1",
98-
]
99-
100-
actual = overrides._filter_patches_based_on_req(lst, req_name)
101-
assert len(expected) == len(actual)
102-
assert expected == actual
103-
104-
105-
def test_regex_for_deepspeed(tmp_path: pathlib.Path):
106-
req_name = "deepspeed"
107-
patches_dir = tmp_path / "patches_dir"
108-
patches_dir.mkdir()
109-
110-
lst = [
111-
patches_dir / "deepspeed-1.1.0",
112-
patches_dir / "deepspeed-deep-2.0.0",
113-
patches_dir / "deepspeed-v2.3.0.post1",
114-
patches_dir / "deepspeed-v5.5.5",
115-
patches_dir / "deepspeed-3.4.4",
116-
patches_dir / "deepspeed-sdg-3.4.4",
117-
]
118-
119-
expected = [
120-
patches_dir / "deepspeed-1.1.0",
121-
patches_dir / "deepspeed-v2.3.0.post1",
122-
patches_dir / "deepspeed-v5.5.5",
123-
patches_dir / "deepspeed-3.4.4",
124-
]
125-
126-
actual = overrides._filter_patches_based_on_req(lst, req_name)
127-
assert len(expected) == len(actual)
128-
assert expected == actual
129-
130-
131-
def test_regex_for_vllm(tmp_path: pathlib.Path):
132-
req_name = "vllm"
133-
patches_dir = tmp_path / "patches_dir"
134-
patches_dir.mkdir()
135-
136-
lst = [
137-
patches_dir / "vllm-1.1.0.9",
138-
patches_dir / "vllm-llm-2.1.0.0",
139-
patches_dir / "vllm-v2.3.5.0.post1",
140-
patches_dir / "vllm-v5.5.5.1",
141-
]
142-
143-
expected = [
144-
patches_dir / "vllm-1.1.0.9",
145-
patches_dir / "vllm-v2.3.5.0.post1",
146-
patches_dir / "vllm-v5.5.5.1",
147-
]
148-
149-
actual = overrides._filter_patches_based_on_req(lst, req_name)
150-
assert len(expected) == len(actual)
151-
assert expected == actual
152-
153-
154-
def test_get_dist_info():
155-
fromager_version = metadata.version("fromager")
156-
plugin_dist, plugin_version = overrides._get_dist_info("fromager.submodule")
157-
assert plugin_dist == "fromager"
158-
assert plugin_version == fromager_version

0 commit comments

Comments
 (0)