Skip to content

Commit ccc3f01

Browse files
committed
look for patches in canonical locations
Require patch directories be named using the override module form of the name of the package being patched. In addition to being more predictable, this allows us to simplify the logic for detecting patches for other versions of a project that are not being applied.
1 parent c135814 commit ccc3f01

File tree

4 files changed

+82
-154
lines changed

4 files changed

+82
-154
lines changed

src/fromager/overrides.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import inspect
2+
import itertools
23
import logging
34
import pathlib
4-
import re
55
import typing
66
from importlib import metadata
77

8+
from packaging.requirements import Requirement
89
from packaging.utils import canonicalize_name
10+
from packaging.version import Version
911
from stevedore import extension
1012

1113
# An interface for reretrieving per-package information which influences
@@ -88,45 +90,39 @@ def log_overrides() -> None:
8890
)
8991

9092

91-
def patches_for_source_dir(
92-
patches_dir: pathlib.Path, source_dir_name: str
93+
def patches_for_requirement(
94+
patches_dir: pathlib.Path,
95+
req: Requirement,
96+
version: Version,
9397
) -> typing.Iterable[pathlib.Path]:
94-
"""Iterator producing patches to apply to the source dir.
98+
"""Iterator producing patches to apply to the source for a given version of a requirement.
9599
96-
Input should be the base directory name, not a full path.
97-
98-
Yields pathlib.Path() references to patches in the order they
99-
should be applied, which is controlled through lexical sorting of
100-
the filenames.
100+
Yields pathlib.Path() references to patches in the order they should be
101+
applied, which is controlled through lexical sorting of the filenames.
101102
102103
"""
103-
return sorted((patches_dir / source_dir_name).glob("*.patch"))
104-
105-
106-
def get_patch_directories(
107-
patches_dir: pathlib.Path, source_root_dir: pathlib.Path
108-
) -> list[pathlib.Path]:
104+
override_name = pkgname_to_override_module(req.name)
105+
unversioned_patch_dir = patches_dir / override_name
106+
versioned_patch_dir = patches_dir / f"{override_name}-{version}"
107+
return itertools.chain(
108+
# Apply all of the unversioned patches first, in order based on
109+
# filename.
110+
sorted(unversioned_patch_dir.glob("*.patch")),
111+
# Then apply any for this specific version, in order based on filename.
112+
sorted(versioned_patch_dir.glob("*.patch")),
113+
)
114+
115+
116+
def get_versioned_patch_directories(
117+
patches_dir: pathlib.Path,
118+
req: Requirement,
119+
) -> typing.Generator[pathlib.Path, None, None]:
109120
"""
110-
This function will return directories that may contain patches for a specific requirement.
111-
It takes in patches directory and a source root directory as input.
112-
The output will be a list of all directories containing patches for that requirement
121+
This function will return directories that may contain patches for any version of a specific requirement.
113122
"""
114123
# Get the req name as per the source_root_dir naming conventions
115-
req_name = source_root_dir.name.rsplit("-", 1)[0]
116-
patches = sorted((patches_dir).glob(f"{req_name}*"))
117-
filtered_patches = _filter_patches_based_on_req(patches, req_name)
118-
return filtered_patches
119-
120-
121-
# Helper method to filter the unwanted patches using a regex
122-
def _filter_patches_based_on_req(
123-
patches: list[pathlib.Path], req_name: str
124-
) -> list[pathlib.Path]:
125-
# Set up regex to filter out unwanted patches.
126-
pattern = re.compile(rf"^{req_name}-v?(\d+\.)+\d+")
127-
filtered_patches = [s for s in patches if pattern.match(s.name)]
128-
# filtered_patches won't contain patches for current version of req
129-
return filtered_patches
124+
override_name = pkgname_to_override_module(req.name)
125+
return patches_dir.glob(f"{override_name}-*")
130126

131127

132128
def pkgname_to_override_module(pkgname: str) -> str:

src/fromager/sources.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -275,27 +275,25 @@ def patch_source(
275275
ctx: context.WorkContext,
276276
source_root_dir: pathlib.Path,
277277
req: Requirement,
278+
version: Version,
278279
) -> None:
279280
patch_count = 0
280-
# apply any unversioned patch first
281-
for p in overrides.patches_for_source_dir(
282-
ctx.settings.patches_dir, overrides.pkgname_to_override_module(req.name)
281+
for p in overrides.patches_for_requirement(
282+
patches_dir=ctx.settings.patches_dir,
283+
req=req,
284+
version=version,
283285
):
284286
_apply_patch(p, source_root_dir)
285287
patch_count += 1
286288

287-
# make sure that we don't apply the generic unversioned patch again
288-
if source_root_dir.name != overrides.pkgname_to_override_module(req.name):
289-
for p in overrides.patches_for_source_dir(
290-
ctx.settings.patches_dir, source_root_dir.name
291-
):
292-
_apply_patch(p, source_root_dir)
293-
patch_count += 1
294-
295289
logger.debug("%s: applied %d patches", req.name, patch_count)
296290
# If no patch has been applied, call warn for old patch
297291
if not patch_count:
298-
_warn_for_old_patch(source_root_dir, ctx.settings.patches_dir)
292+
_warn_for_old_patch(
293+
req=req,
294+
version=version,
295+
patches_dir=ctx.settings.patches_dir,
296+
)
299297

300298

301299
def _apply_patch(patch: pathlib.Path, source_root_dir: pathlib.Path):
@@ -305,18 +303,20 @@ def _apply_patch(patch: pathlib.Path, source_root_dir: pathlib.Path):
305303

306304

307305
def _warn_for_old_patch(
308-
source_root_dir: pathlib.Path,
306+
req: Requirement,
307+
version: Version,
309308
patches_dir: pathlib.Path,
310309
) -> None:
311-
# Get the req name as per the source_root_dir naming conventions
312-
req_name = source_root_dir.name.rsplit("-", 1)[0]
313-
314310
# Filter the patch directories using regex
315-
patch_directories = overrides.get_patch_directories(patches_dir, source_root_dir)
311+
patch_directories = overrides.get_versioned_patch_directories(
312+
patches_dir=patches_dir, req=req
313+
)
316314

317315
for dirs in patch_directories:
318316
for p in dirs.iterdir():
319-
logger.warning(f"{req_name}: patch {p} exists but will not be applied")
317+
logger.warning(
318+
f"{req.name}: patch {p} exists but will not be applied for version {version}"
319+
)
320320

321321

322322
def write_build_meta(
@@ -407,7 +407,7 @@ def prepare_new_source(
407407
408408
`default_prepare_source` runs this function when the sources are new.
409409
"""
410-
patch_source(ctx, source_root_dir, req)
410+
patch_source(ctx, source_root_dir, req, version)
411411
pyproject.apply_project_override(
412412
ctx=ctx,
413413
req=req,

tests/test_overrides.py

Lines changed: 11 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,39 @@
11
import pathlib
2-
from importlib import metadata
32
from unittest import mock
43
from unittest.mock import patch
54

65
import pytest
6+
from packaging.requirements import Requirement
7+
from packaging.version import Version
78

89
from fromager import overrides
910

1011

11-
def test_patches_for_source_dir(tmp_path: pathlib.Path):
12+
def test_patches_for_requirement(tmp_path: pathlib.Path):
1213
patches_dir = tmp_path / "patches"
1314
patches_dir.mkdir()
1415

1516
project_patch_dir = patches_dir / "project-1.2.3"
1617
project_patch_dir.mkdir()
1718

18-
project_variant_patch_dir = patches_dir / "project-1.2.3-variant"
19-
project_variant_patch_dir.mkdir()
20-
2119
p1 = project_patch_dir / "001.patch"
2220
p2 = project_patch_dir / "002.patch"
2321
np1 = project_patch_dir / "not-a-patch.txt"
24-
p3 = project_variant_patch_dir / "003.patch"
25-
np2 = project_variant_patch_dir / "not-a-patch.txt"
2622

2723
# Create all of the test files
28-
for p in [p1, p2, p3]:
24+
for p in [p1, p2]:
2925
p.write_text("this is a patch file")
30-
for f in [np1, np2]:
26+
for f in [np1]:
3127
f.write_text("this is not a patch file")
3228

33-
results = list(overrides.patches_for_source_dir(patches_dir, "project-1.2.3"))
34-
assert results == [p1, p2]
35-
3629
results = list(
37-
overrides.patches_for_source_dir(patches_dir, "project-1.2.3-variant")
30+
overrides.patches_for_requirement(
31+
patches_dir=patches_dir,
32+
req=Requirement("project"),
33+
version=Version("1.2.3"),
34+
)
3835
)
39-
assert results == [p3]
36+
assert results == [p1, p2]
4037

4138

4239
def test_invoke_override_with_exact_args():
@@ -73,86 +70,3 @@ def default_foo(arg1):
7370
assert overrides.find_and_invoke(
7471
"pkg", "foo", default_foo, arg1="value1", arg2="value2"
7572
)
76-
77-
78-
def test_regex_dummy_package(tmp_path: pathlib.Path):
79-
req_name = "foo"
80-
patches_dir = tmp_path / "patches_dir"
81-
patches_dir.mkdir()
82-
83-
lst = [
84-
patches_dir / "foo-1.1.0",
85-
patches_dir / "foo-bar-2.0.0",
86-
patches_dir / "foo-v2.3.0",
87-
patches_dir / "foo-bar-bar-v2.3.1",
88-
patches_dir / "foo-bar-v5.5.5",
89-
patches_dir / "foo-3.4.4",
90-
patches_dir / "foo-v2.3.0.1",
91-
]
92-
93-
expected = [
94-
patches_dir / "foo-1.1.0",
95-
patches_dir / "foo-v2.3.0",
96-
patches_dir / "foo-3.4.4",
97-
patches_dir / "foo-v2.3.0.1",
98-
]
99-
100-
actual = overrides._filter_patches_based_on_req(lst, req_name)
101-
assert len(expected) == len(actual)
102-
assert expected == actual
103-
104-
105-
def test_regex_for_deepspeed(tmp_path: pathlib.Path):
106-
req_name = "deepspeed"
107-
patches_dir = tmp_path / "patches_dir"
108-
patches_dir.mkdir()
109-
110-
lst = [
111-
patches_dir / "deepspeed-1.1.0",
112-
patches_dir / "deepspeed-deep-2.0.0",
113-
patches_dir / "deepspeed-v2.3.0.post1",
114-
patches_dir / "deepspeed-v5.5.5",
115-
patches_dir / "deepspeed-3.4.4",
116-
patches_dir / "deepspeed-sdg-3.4.4",
117-
]
118-
119-
expected = [
120-
patches_dir / "deepspeed-1.1.0",
121-
patches_dir / "deepspeed-v2.3.0.post1",
122-
patches_dir / "deepspeed-v5.5.5",
123-
patches_dir / "deepspeed-3.4.4",
124-
]
125-
126-
actual = overrides._filter_patches_based_on_req(lst, req_name)
127-
assert len(expected) == len(actual)
128-
assert expected == actual
129-
130-
131-
def test_regex_for_vllm(tmp_path: pathlib.Path):
132-
req_name = "vllm"
133-
patches_dir = tmp_path / "patches_dir"
134-
patches_dir.mkdir()
135-
136-
lst = [
137-
patches_dir / "vllm-1.1.0.9",
138-
patches_dir / "vllm-llm-2.1.0.0",
139-
patches_dir / "vllm-v2.3.5.0.post1",
140-
patches_dir / "vllm-v5.5.5.1",
141-
]
142-
143-
expected = [
144-
patches_dir / "vllm-1.1.0.9",
145-
patches_dir / "vllm-v2.3.5.0.post1",
146-
patches_dir / "vllm-v5.5.5.1",
147-
]
148-
149-
actual = overrides._filter_patches_based_on_req(lst, req_name)
150-
assert len(expected) == len(actual)
151-
assert expected == actual
152-
153-
154-
def test_get_dist_info():
155-
fromager_version = metadata.version("fromager")
156-
plugin_dist, plugin_version = overrides._get_dist_info("fromager.submodule")
157-
assert plugin_dist == "fromager"
158-
assert plugin_version == fromager_version

tests/test_sources.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ def test_patch_sources_apply_unversioned_and_versioned(
121121
source_root_dir = tmp_path / "deepspeed-0.5.0"
122122
source_root_dir.mkdir()
123123

124-
sources.patch_source(tmp_context, source_root_dir, Requirement("deepspeed==0.5.0"))
124+
sources.patch_source(
125+
ctx=tmp_context,
126+
source_root_dir=source_root_dir,
127+
req=Requirement("deepspeed==0.5.0"),
128+
version=Version("0.5.0"),
129+
)
125130
assert apply_patch.call_count == 2
126131
apply_patch.assert_has_calls(
127132
[
@@ -151,10 +156,15 @@ def test_patch_sources_apply_only_unversioned(
151156
unversioned_patch_file = deepspeed_unversioned_patch / "deepspeed-update.patch"
152157
unversioned_patch_file.write_text("This is a test patch")
153158

154-
source_root_dir = tmp_path / "deepspeed"
159+
source_root_dir = tmp_path / "deepspeed-0.5.0"
155160
source_root_dir.mkdir()
156161

157-
sources.patch_source(tmp_context, source_root_dir, Requirement("deepspeed==0.5.0"))
162+
sources.patch_source(
163+
ctx=tmp_context,
164+
source_root_dir=source_root_dir,
165+
req=Requirement("deepspeed"),
166+
version=Version("0.6.0"),
167+
)
158168
assert apply_patch.call_count == 1
159169
apply_patch.assert_has_calls(
160170
[
@@ -179,7 +189,11 @@ def test_warning_for_older_patch(mock, tmp_path: pathlib.Path):
179189
source_root_dir = tmp_path / "deepspeed-0.6.0"
180190
source_root_dir.mkdir()
181191

182-
sources._warn_for_old_patch(source_root_dir, patches_dir)
192+
sources._warn_for_old_patch(
193+
req=Requirement("deepspeed"),
194+
version=Version("0.6.0"),
195+
patches_dir=patches_dir,
196+
)
183197
mock.assert_called()
184198

185199

@@ -199,5 +213,9 @@ def test_warning_for_older_patch_different_req(mock, tmp_path: pathlib.Path):
199213
source_root_dir = tmp_path / "deepspeed-0.5.0"
200214
source_root_dir.mkdir()
201215

202-
sources._warn_for_old_patch(source_root_dir, patches_dir)
216+
sources._warn_for_old_patch(
217+
req=Requirement("deepspeed"),
218+
version=Version("0.5.0"),
219+
patches_dir=patches_dir,
220+
)
203221
mock.assert_not_called()

0 commit comments

Comments
 (0)