|
14 | 14 | from typing import TYPE_CHECKING
|
15 | 15 |
|
16 | 16 | import packaging.requirements
|
| 17 | +import packaging.specifiers |
17 | 18 | import packaging.utils
|
18 | 19 | import packaging.version
|
19 | 20 | import pytest
|
@@ -233,7 +234,68 @@ class VersionData:
|
233 | 234 | pytest.fail(f"{name} has multiple versions: {pprint.pformat(mapping)}")
|
234 | 235 |
|
235 | 236 |
|
236 |
| -# TODO(jdanek): ^^^ should also check pyproject.tomls, in fact checking there is more useful than in manifests |
| 237 | +def test_image_pyprojects_version_alignment(subtests: pytest_subtests.plugin.SubTests): |
| 238 | + requirements = defaultdict(list) |
| 239 | + for file in PROJECT_ROOT.glob("**/pyproject.toml"): |
| 240 | + logging.info(file) |
| 241 | + directory = file.parent # "ubi9-python-3.11" |
| 242 | + try: |
| 243 | + _ubi, _lang, _python = directory.name.split("-") |
| 244 | + except ValueError: |
| 245 | + logging.debug(f"skipping {directory.name}/pyproject.toml as it is not an image directory") |
| 246 | + continue |
| 247 | + |
| 248 | + if _skip_unimplemented_manifests(directory, call_skip=False): |
| 249 | + continue |
| 250 | + |
| 251 | + pyproject = tomllib.loads(file.read_text()) |
| 252 | + for d in pyproject["project"]["dependencies"]: |
| 253 | + requirement = packaging.requirements.Requirement(d) |
| 254 | + requirements[requirement.name].append(requirement.specifier) |
| 255 | + |
| 256 | + # TODO(jdanek): review these, if any are unwarranted |
| 257 | + ignored_exceptions: tuple[tuple[str, tuple[str, ...]], ...] = ( |
| 258 | + # ("package name", ("allowed specifier 1", "allowed specifier 2", ...)) |
| 259 | + ("setuptools", ("~=78.1.1", "==78.1.1")), |
| 260 | + ("wheel", ("==0.45.1", "~=0.45.1")), |
| 261 | + ("tensorboard", ("~=2.18.0", "~=2.19.0")), |
| 262 | + ("torch", ("==2.6.0", "==2.6.0+cu126", "==2.6.0+rocm6.2.4")), |
| 263 | + ("torchvision", ("==0.21.0", "==0.21.0+cu126", "==0.21.0+rocm6.2.4")), |
| 264 | + ("matplotlib", ("~=3.10.1", "~=3.10.3")), |
| 265 | + ("numpy", ("~=2.2.3", "<2.0.0", "~=1.26.4")), |
| 266 | + ("pandas", ("~=2.2.3", "~=1.5.3")), |
| 267 | + ("scikit-learn", ("~=1.6.1", "~=1.7.0")), |
| 268 | + ("codeflare-sdk", ("~=0.29.0", "~=0.30.0")), |
| 269 | + ("ipython-genutils", (">=0.2.0", "~=0.2.0")), |
| 270 | + ("jinja2", (">=3.1.6", "~=3.1.6")), |
| 271 | + ("jupyter-client", ("~=8.6.3", ">=8.6.3")), |
| 272 | + ("requests", ("~=2.32.3", ">=2.0.0")), |
| 273 | + ("urllib3", ("~=2.5.0", "~=2.3.0")), |
| 274 | + ("transformers", ("<5.0,>4.0", "~=4.55.0")), |
| 275 | + ("datasets", ("", "~=3.4.1")), |
| 276 | + ("accelerate", ("!=1.1.0,>=0.20.3", "~=1.5.2")), |
| 277 | + ("kubeflow-training", ("==1.9.0", "==1.9.2", "==1.9.3")), |
| 278 | + ("jupyter-bokeh", ("~=3.0.5", "~=4.0.5")), |
| 279 | + ("jupyterlab-lsp", ("~=5.1.0", "~=5.1.1")), |
| 280 | + ("jupyterlab-widgets", ("~=3.0.13", "~=3.0.15")), |
| 281 | + ) |
| 282 | + |
| 283 | + for name, data in requirements.items(): |
| 284 | + if len(set(data)) == 1: |
| 285 | + continue |
| 286 | + |
| 287 | + with subtests.test(msg=f"checking versions of {name} across all pyproject.tomls"): |
| 288 | + exception = next((it for it in ignored_exceptions if it[0] == name), None) |
| 289 | + if exception: |
| 290 | + # exception may save us from failing |
| 291 | + if set(data) == {packaging.specifiers.SpecifierSet(e) for e in exception[1]}: |
| 292 | + continue |
| 293 | + else: |
| 294 | + pytest.fail( |
| 295 | + f"{name} is allowed to have {exception[1]} but actually has more specifiers: {pprint.pformat(set(data))}" |
| 296 | + ) |
| 297 | + # all hope is lost, the check has failed |
| 298 | + pytest.fail(f"{name} has multiple specifiers: {pprint.pformat(data)}") |
237 | 299 |
|
238 | 300 |
|
239 | 301 | def test_files_that_should_be_same_are_same(subtests: pytest_subtests.plugin.SubTests):
|
|
0 commit comments