|
7 | 7 | .default-before-script: &default-before-script |
8 | 8 | - python --version |
9 | 9 |
|
| 10 | +.check-torch-cuda: &check-torch-cuda |
| 11 | + - | |
| 12 | + TORCH_VERSION="$TORCH_VERSION" python <<EOF |
| 13 | + import os, torch |
| 14 | + from packaging import version |
| 15 | + print(f"{torch.__version__=}\n{torch.version.cuda=}\n{torch.cuda.is_available()=}") |
| 16 | + expected_torch_version = os.environ["TORCH_VERSION"] |
| 17 | + assert ( |
| 18 | + version.parse(torch.__version__).base_version |
| 19 | + == version.parse(expected_torch_version).base_version |
| 20 | + ), f"Expected torch.__version__={expected_torch_version}, but got {torch.__version__=}" |
| 21 | + # assert torch.cuda.is_available(), "CUDA is not available" |
| 22 | + EOF |
| 23 | +
|
10 | 24 | wheel: |
11 | 25 | image: python:$PYTHON_VERSION-buster |
12 | 26 | stage: build |
|
78 | 92 | image: pytorch/pytorch:$PYTORCH_IMAGE |
79 | 93 | before_script: |
80 | 94 | - *default-before-script |
| 95 | + - | |
| 96 | + if [ "$PYTORCH_IMAGE" == "1.13.1-cuda11.6-cudnn8-devel" ]; then |
| 97 | + CUDA_VERSION_NAME=cu116 |
| 98 | + TORCH_VERSION=1.13.1+cu116 |
| 99 | + TORCHVISION_VERSION=0.14.1+cu116 |
| 100 | + elif [ "$PYTORCH_IMAGE" == "2.2.0-cuda11.8-cudnn8-devel" ]; then |
| 101 | + CUDA_VERSION_NAME=cu118 |
| 102 | + TORCH_VERSION=2.2.0+cu118 |
| 103 | + TORCHVISION_VERSION=0.17.0+cu118 |
| 104 | + fi |
81 | 105 | - python -m pip install -U pip |
82 | 106 | - python -m pip install -e . |
83 | 107 | - python -m pip install pytest pytest-cov plotly |
| 108 | + - *check-torch-cuda |
84 | 109 | - | |
85 | 110 | PYTEST_ARGS=(--cov=compressai --capture=no tests) |
86 | 111 | if [ "$CI_COMMIT_BRANCH" != "master" ]; then |
|
0 commit comments