File tree Expand file tree Collapse file tree 5 files changed +79
-79
lines changed
Expand file tree Collapse file tree 5 files changed +79
-79
lines changed Original file line number Diff line number Diff line change 2121 python-version : ${{ matrix.python-version }}
2222
2323 - name : Install the project
24- run : uv sync --locked --all-extras --dev
24+ run : uv sync --reinstall -- locked --all-extras --dev
2525
2626 - name : Run unit tests
2727 run : uv run pytest -m "not slow" --cov=compressai -s tests/
Original file line number Diff line number Diff line change @@ -156,7 +156,7 @@ test:
156156 update_dict(base, delta)
157157 Path("pyproject.toml").write_text(tomlkit.dumps(base))
158158 EOF
159- - uv sync --group=test
159+ - uv sync --group=test --reinstall
160160 - *check-torch-cuda
161161 - |
162162 PYTEST_ARGS=(--cov=compressai --capture=no tests)
Original file line number Diff line number Diff line change @@ -31,8 +31,7 @@ classifiers = [
3131dependencies = [
3232 " einops" ,
3333 " matplotlib" ,
34- " numpy >= 1.24.4,<2.0; python_version < '3.10'" ,
35- " numpy >= 2.2.4; python_version >= '3.10'" ,
34+ " numpy >= 1.24.4,<2.0" ,
3635 " Pandas >= 2.0; python_version < '3.10'" ,
3736 " Pandas >= 2.2.0; python_version >= '3.10'" ,
3837 " pybind11>=2.6.0; python_version < '3.10'" , # For --no-build-isolation.
Original file line number Diff line number Diff line change 4040 import tomli as tomllib
4141
4242
43+ from packaging .version import parse as parse_version
44+
45+ try :
46+ import torch
47+
48+ torch_version = parse_version (torch .__version__ )
49+ except ImportError :
50+ torch_version = parse_version ("0.0" )
51+
52+ # Logic to determine numpy version based on torch version
53+ if torch_version < parse_version ("2.2" ):
54+ numpy_req = "numpy<2"
55+ else :
56+ numpy_req = "numpy>=2.0"
57+
58+
4359with open ("pyproject.toml" , "rb" ) as f :
4460 pyproject = tomllib .load (f )
4561
@@ -106,4 +122,7 @@ def find_sources(path):
106122 name = package_name ,
107123 ext_modules = get_extensions (),
108124 cmdclass = {"build_ext" : build_ext },
125+ install_requires = [
126+ numpy_req ,
127+ ],
109128)
You can’t perform that action at this time.
0 commit comments