Skip to content

Commit f8416be

Browse files
otaj authored and lexierule committed
[CI] Trick Bagua into installing appropriate wheel in GPU tests (#14380)
The Bagua trick needs to be replicated everywhere it is applicable.
1 parent 48ddb0a commit f8416be

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

.azure/gpu-tests.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,8 @@ jobs:
73 73
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
74 74
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
75 75
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
76 -
pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
76 +
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [115,113,111,102] if $CUDA_VERSION_MM >= ver][0])")
77 +
pip install "bagua-cuda$CUDA_VERSION_BAGUA>=0.9.0"
77 78
pip install -e .[strategies]
78 79
pip install -U deepspeed # TODO: remove when docker images are upgraded
79 80
pip install --requirement requirements/pytorch/devel.txt

dockers/base-conda/Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -141,8 +141,9 @@ RUN \
141 141
RUN \
142 142
# install Bagua
143 143
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
144 -
pip install "bagua-cuda$CUDA_VERSION_MM==0.9.0" && \
145 -
python -c "import bagua_core; bagua_core.install_deps()" && \
144 +
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [115,113,111,102] if $CUDA_VERSION_MM >= ver][0])") && \
145 +
pip install "bagua-cuda$CUDA_VERSION_BAGUA==0.9.0" && \
146 +
if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()"; fi && \
146 147
python -c "import bagua; print(bagua.__version__)"
147 148

148 149
RUN \

0 commit comments

Comments
 (0)