Skip to content

Commit 1fecb8e

Browse files
authored
RHAIENG-308, opendatahub-io#2242: tests(make): add pyproject version alignment test and handle allowed specifier divergences (opendatahub-io#2276)
* Add test_image_pyprojects_version_alignment to compare dependency specifiers across all pyproject.toml files * Import packaging.specifiers and store actual SpecifierSet objects instead of string reprs * Introduce ignored_exceptions for known, acceptable specifier differences (e.g., torch variants, numpy caps) * Compare specifiers using SpecifierSet equality and wrap checks in subtests for clearer reporting * Minor cleanup: remove unused pyprojects variable
1 parent ae521a6 commit 1fecb8e

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

tests/test_main.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import TYPE_CHECKING
1515

1616
import packaging.requirements
17+
import packaging.specifiers
1718
import packaging.utils
1819
import packaging.version
1920
import pytest
@@ -233,7 +234,68 @@ class VersionData:
233234
pytest.fail(f"{name} has multiple versions: {pprint.pformat(mapping)}")
234235

235236

236-
# TODO(jdanek): ^^^ should also check pyproject.tomls, in fact checking there is more useful than in manifests
237+
def test_image_pyprojects_version_alignment(subtests: pytest_subtests.plugin.SubTests):
238+
requirements = defaultdict(list)
239+
for file in PROJECT_ROOT.glob("**/pyproject.toml"):
240+
logging.info(file)
241+
directory = file.parent # "ubi9-python-3.11"
242+
try:
243+
_ubi, _lang, _python = directory.name.split("-")
244+
except ValueError:
245+
logging.debug(f"skipping {directory.name}/pyproject.toml as it is not an image directory")
246+
continue
247+
248+
if _skip_unimplemented_manifests(directory, call_skip=False):
249+
continue
250+
251+
pyproject = tomllib.loads(file.read_text())
252+
for d in pyproject["project"]["dependencies"]:
253+
requirement = packaging.requirements.Requirement(d)
254+
requirements[requirement.name].append(requirement.specifier)
255+
256+
# TODO(jdanek): review these, if any are unwarranted
257+
ignored_exceptions: tuple[tuple[str, tuple[str, ...]], ...] = (
258+
# ("package name", ("allowed specifier 1", "allowed specifier 2", ...))
259+
("setuptools", ("~=78.1.1", "==78.1.1")),
260+
("wheel", ("==0.45.1", "~=0.45.1")),
261+
("tensorboard", ("~=2.18.0", "~=2.19.0")),
262+
("torch", ("==2.6.0", "==2.6.0+cu126", "==2.6.0+rocm6.2.4")),
263+
("torchvision", ("==0.21.0", "==0.21.0+cu126", "==0.21.0+rocm6.2.4")),
264+
("matplotlib", ("~=3.10.1", "~=3.10.3")),
265+
("numpy", ("~=2.2.3", "<2.0.0", "~=1.26.4")),
266+
("pandas", ("~=2.2.3", "~=1.5.3")),
267+
("scikit-learn", ("~=1.6.1", "~=1.7.0")),
268+
("codeflare-sdk", ("~=0.29.0", "~=0.30.0")),
269+
("ipython-genutils", (">=0.2.0", "~=0.2.0")),
270+
("jinja2", (">=3.1.6", "~=3.1.6")),
271+
("jupyter-client", ("~=8.6.3", ">=8.6.3")),
272+
("requests", ("~=2.32.3", ">=2.0.0")),
273+
("urllib3", ("~=2.5.0", "~=2.3.0")),
274+
("transformers", ("<5.0,>4.0", "~=4.55.0")),
275+
("datasets", ("", "~=3.4.1")),
276+
("accelerate", ("!=1.1.0,>=0.20.3", "~=1.5.2")),
277+
("kubeflow-training", ("==1.9.0", "==1.9.2", "==1.9.3")),
278+
("jupyter-bokeh", ("~=3.0.5", "~=4.0.5")),
279+
("jupyterlab-lsp", ("~=5.1.0", "~=5.1.1")),
280+
("jupyterlab-widgets", ("~=3.0.13", "~=3.0.15")),
281+
)
282+
283+
for name, data in requirements.items():
284+
if len(set(data)) == 1:
285+
continue
286+
287+
with subtests.test(msg=f"checking versions of {name} across all pyproject.tomls"):
288+
exception = next((it for it in ignored_exceptions if it[0] == name), None)
289+
if exception:
290+
# exception may save us from failing
291+
if set(data) == {packaging.specifiers.SpecifierSet(e) for e in exception[1]}:
292+
continue
293+
else:
294+
pytest.fail(
295+
f"{name} is allowed to have {exception[1]} but actually has more specifiers: {pprint.pformat(set(data))}"
296+
)
297+
# all hope is lost, the check has failed
298+
pytest.fail(f"{name} has multiple specifiers: {pprint.pformat(data)}")
237299

238300

239301
def test_files_that_should_be_same_are_same(subtests: pytest_subtests.plugin.SubTests):

0 commit comments

Comments
 (0)