Skip to content

Commit 7cbae3e

Browse files
committed
Update .gitlab-ci.yml file
1 parent 97dd73c commit 7cbae3e

File tree

5 files changed

+79
-79
lines changed

5 files changed

+79
-79
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
python-version: ${{ matrix.python-version }}
2222

2323
- name: Install the project
24-
run: uv sync --locked --all-extras --dev
24+
run: uv sync --reinstall --locked --all-extras --dev
2525

2626
- name: Run unit tests
2727
run: uv run pytest -m "not slow" --cov=compressai -s tests/

.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ test:
156156
update_dict(base, delta)
157157
Path("pyproject.toml").write_text(tomlkit.dumps(base))
158158
EOF
159-
- uv sync --group=test
159+
- uv sync --group=test --reinstall
160160
- *check-torch-cuda
161161
- |
162162
PYTEST_ARGS=(--cov=compressai --capture=no tests)

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ classifiers = [
3131
dependencies = [
3232
"einops",
3333
"matplotlib",
34-
"numpy >= 1.24.4,<2.0; python_version < '3.10'",
35-
"numpy >= 2.2.4; python_version >= '3.10'",
34+
"numpy >= 1.24.4,<2.0",
3635
"Pandas >= 2.0; python_version < '3.10'",
3736
"Pandas >= 2.2.0; python_version >= '3.10'",
3837
"pybind11>=2.6.0; python_version < '3.10'", # For --no-build-isolation.

setup.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@
4040
import tomli as tomllib
4141

4242

43+
from packaging.version import parse as parse_version
44+
45+
try:
46+
import torch
47+
48+
torch_version = parse_version(torch.__version__)
49+
except ImportError:
50+
torch_version = parse_version("0.0")
51+
52+
# Logic to determine numpy version based on torch version
53+
if torch_version < parse_version("2.2"):
54+
numpy_req = "numpy<2"
55+
else:
56+
numpy_req = "numpy>=2.0"
57+
58+
4359
with open("pyproject.toml", "rb") as f:
4460
pyproject = tomllib.load(f)
4561

@@ -106,4 +122,7 @@ def find_sources(path):
106122
name=package_name,
107123
ext_modules=get_extensions(),
108124
cmdclass={"build_ext": build_ext},
125+
install_requires=[
126+
numpy_req,
127+
],
109128
)

0 commit comments

Comments
 (0)