Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed `compare_version` if runtime error ([#427](https://github.com/Lightning-AI/utilities/pull/427))
- Remove deprecated `pkg_resources` usage for `setuptools >= 82` compatibility ([#473](https://github.com/Lightning-AI/utilities/pull/473))


---
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[build-system]
requires = [
"packaging",
"setuptools",
"wheel",
]
Expand Down
18 changes: 14 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python
import glob
import os
from collections.abc import Iterator
from importlib.util import module_from_spec, spec_from_file_location

from pkg_resources import parse_requirements
from packaging.requirements import Requirement
from setuptools import find_packages, setup

_PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
Expand All @@ -20,9 +21,19 @@ def _load_py_module(fname: str, pkg: str = "lightning_utilities"):

about = _load_py_module("__about__.py")


# load basic requirements
def _parse_requirements(lines: list[str]) -> Iterator[str]:
"""Parse requirements from lines using packaging."""
for line in lines:
line = line.strip()
if not line or line.startswith("#"):
continue
yield str(Requirement(line))


with open(os.path.join(_PATH_REQUIRE, "core.txt")) as fp:
requirements = list(map(str, parse_requirements(fp.readlines())))
requirements = list(_parse_requirements(fp.readlines()))


# make extras as automated loading
Expand All @@ -36,8 +47,7 @@ def _requirement_extras(path_req: str = _PATH_REQUIRE) -> dict:
continue
name, _ = os.path.splitext(fname)
with open(fpath) as fp:
reqs = parse_requirements(fp.readlines())
extras[name] = list(map(str, reqs))
extras[name] = list(_parse_requirements(fp.readlines()))
return extras


Expand Down
5 changes: 2 additions & 3 deletions src/lightning_utilities/docs/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,15 @@ def _load_pypi_versions(package_name: str) -> list[str]:
['0.9', '0.10', '0.11', '0.12', ...]

"""
from distutils.version import LooseVersion

import requests
from packaging.version import Version

url = f"https://pypi.org/pypi/{package_name}/json"
data = requests.get(url, timeout=10).json()
versions = data["releases"].keys()
# filter all version which include only numbers and dots
versions = {k for k in versions if re.match(r"^\d+(\.\d+)*$", k)}
return sorted(versions, key=LooseVersion)
return sorted(versions, key=Version)


def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
Expand Down
39 changes: 29 additions & 10 deletions src/lightning_utilities/install/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,30 @@

import re
from collections.abc import Iterable, Iterator
from distutils.version import LooseVersion
from pathlib import Path
from typing import Any, Optional, Union

from pkg_resources import Requirement, yield_lines # type: ignore[import-untyped]
from packaging.requirements import Requirement
from packaging.version import Version


def yield_lines(strs: Union[str, Iterable[str]]) -> Iterator[str]:
"""Yield lines from a string or iterable, handling line continuations.

Args:
strs: Either an iterable of strings or a single multi-line string.

Yields:
Individual lines with continuations resolved.

"""
if isinstance(strs, str):
strs = strs.splitlines()
for line in strs:
line = line.strip()
if not line or line.startswith("#"):
continue
yield line


class _RequirementWithComment(Requirement):
Expand Down Expand Up @@ -77,23 +96,23 @@ def adjust(self, unfreeze: str) -> str:
if self.strict:
return f"{out} {self.strict_string}"
if unfreeze == "major":
for operator, version in self.specs:
if operator in ("<", "<="):
major = LooseVersion(version).version[0]
for spec in self.specifier:
if spec.operator in ("<", "<="):
major = Version(spec.version).major
# replace upper bound with major version increased by one
return out.replace(f"{operator}{version}", f"<{int(major) + 1}.0")
return out.replace(f"{spec.operator}{spec.version}", f"<{int(major) + 1}.0")
elif unfreeze == "all":
for operator, version in self.specs:
if operator in ("<", "<="):
for spec in self.specifier:
if spec.operator in ("<", "<="):
# drop upper bound
return out.replace(f"{operator}{version},", "")
return out.replace(f"{spec.operator}{spec.version},", "")
elif unfreeze != "none":
raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.")
return out


def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
r"""Adapted from ``pkg_resources.parse_requirements`` to include comments and pip arguments.
r"""Parse requirement lines preserving comments and pip arguments.

Parses a sequence or string of requirement lines, preserving trailing comments and associating any
preceding pip arguments (``--...``) with the subsequent requirement. Lines starting with ``-r`` or
Expand Down
Empty file.
129 changes: 129 additions & 0 deletions tests/unittests/install/test_requirements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from pathlib import Path

import pytest

from lightning_utilities.install.requirements import (
_parse_requirements,
_RequirementWithComment,
load_requirements,
yield_lines,
)

_PATH_ROOT = Path(__file__).parent.parent.parent.parent


def test_yield_lines_from_list():
assert list(yield_lines(["foo", " bar ", "", "# comment", "baz"])) == ["foo", "bar", "baz"]


def test_yield_lines_from_string():
assert list(yield_lines("foo\n bar \n\n# comment\nbaz")) == ["foo", "bar", "baz"]


def test_yield_lines_empty():
assert list(yield_lines([])) == []
assert list(yield_lines("")) == []


def test_requirement_with_comment_attributes():
req = _RequirementWithComment("arrow>=1.0", comment="# my comment")
assert req.name == "arrow"
assert req.comment == "# my comment"
assert req.pip_argument is None
assert req.strict is False


def test_requirement_with_comment_strict():
assert _RequirementWithComment("arrow>=1.0", comment="# strict").strict is True
assert _RequirementWithComment("arrow>=1.0", comment="# Strict pinning").strict is True


def test_requirement_with_comment_pip_argument():
req = _RequirementWithComment("arrow>=1.0", pip_argument="--extra-index-url https://x")
assert req.pip_argument == "--extra-index-url https://x"

with pytest.raises(RuntimeError, match="wrong pip argument"):
_RequirementWithComment("arrow>=1.0", pip_argument="")


def test_adjust_none():
assert _RequirementWithComment("arrow<=1.2,>=1.0").adjust("none") == "arrow<=1.2,>=1.0"
assert (
_RequirementWithComment("arrow<=1.2,>=1.0", comment="# strict").adjust("none") == "arrow<=1.2,>=1.0 # strict"
)


def test_adjust_all():
assert _RequirementWithComment("arrow<=1.2,>=1.0").adjust("all") == "arrow>=1.0"
assert _RequirementWithComment("arrow<=1.2,>=1.0", comment="# strict").adjust("all") == "arrow<=1.2,>=1.0 # strict"
assert _RequirementWithComment("arrow").adjust("all") == "arrow"


def test_adjust_major():
assert _RequirementWithComment("arrow>=1.2.0, <=1.2.2").adjust("major") == "arrow<2.0,>=1.2.0"
assert _RequirementWithComment("lib>=0.5, <=0.9").adjust("major") == "lib<1.0,>=0.5"
assert (
_RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
== "arrow<=1.2.2,>=1.2.0 # strict"
)
assert _RequirementWithComment("arrow>=1.2.0").adjust("major") == "arrow>=1.2.0"


def test_adjust_invalid_unfreeze():
with pytest.raises(ValueError, match="Unexpected unfreeze"):
_RequirementWithComment("arrow>=1.0").adjust("invalid")


def test_parse_requirements_basic():
reqs = list(_parse_requirements(["# comment", "", "numpy>=1.0", "pandas<2.0"]))
assert [str(r) for r in reqs] == ["numpy>=1.0", "pandas<2.0"]


def test_parse_requirements_from_string():
reqs = list(_parse_requirements("# comment\n\nnumpy>=1.0\npandas<2.0"))
assert [str(r) for r in reqs] == ["numpy>=1.0", "pandas<2.0"]


def test_parse_requirements_preserves_comments():
reqs = list(_parse_requirements(["arrow>=1.0 # strict"]))
assert len(reqs) == 1
assert reqs[0].comment == " # strict"
assert reqs[0].strict is True


def test_parse_requirements_pip_argument():
reqs = list(_parse_requirements(["--extra-index-url https://x", "torch>=2.0"]))
assert len(reqs) == 1
assert reqs[0].pip_argument == "--extra-index-url https://x"


def test_parse_requirements_skips():
reqs = list(_parse_requirements(["-r other.txt", "pesq @ git+https://github.com/foo/bar", "numpy"]))
assert len(reqs) == 1
assert reqs[0].name == "numpy"


def test_load_requirements_core():
path_req = str(_PATH_ROOT / "requirements")
reqs = load_requirements(path_req, "core.txt", unfreeze="all")
assert len(reqs) > 0
assert any("packaging" in r for r in reqs)


def test_load_requirements_nonexistent(tmpdir):
with pytest.raises(FileNotFoundError):
load_requirements(str(tmpdir), "nonexistent.txt")


def test_load_requirements_invalid_unfreeze(tmpdir):
with pytest.raises(ValueError, match="unsupported"):
load_requirements(str(tmpdir), "x.txt", unfreeze="bad")


def test_load_requirements_unfreeze_strategies(tmpdir):
req_file = tmpdir / "test.txt"
req_file.write("arrow>=1.2.0, <=1.2.2\n")

assert load_requirements(str(tmpdir), "test.txt", unfreeze="none") == ["arrow<=1.2.2,>=1.2.0"]
assert load_requirements(str(tmpdir), "test.txt", unfreeze="major") == ["arrow<2.0,>=1.2.0"]
assert load_requirements(str(tmpdir), "test.txt", unfreeze="all") == ["arrow>=1.2.0"]
Loading