Skip to content

Commit 107f518

Browse files
committed
ci: check-torch-cuda (assert correct version)
1 parent a96c17c commit 107f518

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

.gitlab-ci.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@ stages:
77
.default-before-script: &default-before-script
88
- python --version
99

10+
.check-torch-cuda: &check-torch-cuda
11+
- |
12+
TORCH_VERSION="$TORCH_VERSION" python <<EOF
13+
import os, torch
14+
from packaging import version
15+
print(f"{torch.__version__=}\n{torch.version.cuda=}\n{torch.cuda.is_available()=}")
16+
expected_torch_version = os.environ["TORCH_VERSION"]
17+
assert (
18+
version.parse(torch.__version__).base_version
19+
== version.parse(expected_torch_version).base_version
20+
), f"Expected torch.__version__={expected_torch_version}, but got {torch.__version__=}"
21+
# assert torch.cuda.is_available(), "CUDA is not available"
22+
EOF
23+
1024
wheel:
1125
image: python:$PYTHON_VERSION-buster
1226
stage: build
@@ -78,9 +92,20 @@ test:
7892
image: pytorch/pytorch:$PYTORCH_IMAGE
7993
before_script:
8094
- *default-before-script
95+
- |
96+
if [ "$PYTORCH_IMAGE" == "1.13.1-cuda11.6-cudnn8-devel" ]; then
97+
CUDA_VERSION_NAME=cu116
98+
TORCH_VERSION=1.13.1+cu116
99+
TORCHVISION_VERSION=0.14.1+cu116
100+
elif [ "$PYTORCH_IMAGE" == "2.2.0-cuda11.8-cudnn8-devel" ]; then
101+
CUDA_VERSION_NAME=cu118
102+
TORCH_VERSION=2.2.0+cu118
103+
TORCHVISION_VERSION=0.17.0+cu118
104+
fi
81105
- python -m pip install -U pip
82106
- python -m pip install -e .
83107
- python -m pip install pytest pytest-cov plotly
108+
- *check-torch-cuda
84109
- |
85110
PYTEST_ARGS=(--cov=compressai --capture=no tests)
86111
if [ "$CI_COMMIT_BRANCH" != "master" ]; then

0 commit comments

Comments
 (0)