Skip to content

Commit dae45e3

Browse files
committed
Add cuda-entrypoint.sh and update Dockerfile-cuda
1 parent 000a081 commit dae45e3

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

Dockerfile-cuda

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,26 @@ ARG DEFAULT_USE_FLASH_ATTENTION=True
126126
ENV HUGGINGFACE_HUB_CACHE=/data \
127127
PORT=80 \
128128
USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION \
129-
LD_LIBRARY_PATH="/usr/local/cuda/compat:${LD_LIBRARY_PATH}"
129+
LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
130130

131131
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
132132
ca-certificates \
133133
libssl-dev \
134134
curl \
135-
cuda-compat-13-1 \
135+
cuda-compat-12-9 \
136136
&& rm -rf /var/lib/apt/lists/*
137137

138+
COPY --chmod=775 cuda-entrypoint.sh entrypoint.sh
139+
138140
FROM base AS grpc
139141

140142
COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
141143

142-
ENTRYPOINT ["text-embeddings-router"]
144+
ENTRYPOINT ["./entrypoint.sh"]
143145
CMD ["--json-output"]
144146

145147
FROM base
146148

147149
COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
148-
149-
ENTRYPOINT ["text-embeddings-router"]
150+
ENTRYPOINT ["./entrypoint.sh"]
150151
CMD ["--json-output"]

cuda-entrypoint.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash
#
# Container entrypoint for the CUDA image of text-embeddings-router.
#
# Decides whether the CUDA forward-compatibility libraries installed under
# /usr/local/cuda/compat have to be put on LD_LIBRARY_PATH, then replaces
# itself with `text-embeddings-router`, forwarding all arguments unchanged.
#
# Exits with status 1 when `nvidia-smi` is not available.

if ! command -v nvidia-smi &>/dev/null; then
    # Diagnostics belong on stderr so they do not mix with regular output.
    echo "Error: 'nvidia-smi' command not found." >&2
    exit 1
fi

# NOTE: Hosts whose driver supports a CUDA version lower than 12.9.1 need the
# `cuda-compat-12-9` libraries on `LD_LIBRARY_PATH`. Hosts at or above that
# version (in particular CUDA 13.0+) must NOT get them, as loading the compat
# libraries there makes the application fail.
if [ -d /usr/local/cuda/compat ]; then
    # The nvidia-smi banner line has the form
    #   | NVIDIA-SMI 550.54.15   Driver Version: 550.54.15   CUDA Version: 12.4 |
    # so the CUDA version is the token that follows "CUDA Version:". Picking a
    # fixed field such as $3 would return the driver version instead.
    DRIVER_CUDA=$(nvidia-smi 2>/dev/null | awk '
        /CUDA Version/ {
            for (i = 1; i + 2 <= NF; i++) {
                if ($i == "CUDA" && $(i + 1) == "Version:") {
                    print $(i + 2)
                    exit
                }
            }
        }')

    # Anything that is not a dotted numeric version (no output at all, "N/A",
    # ...) is treated as 0.0.0, i.e. "older than the target", which keeps the
    # same fallback the empty-output case always had and avoids feeding
    # non-numeric text to the arithmetic below.
    if ! [[ "$DRIVER_CUDA" =~ ^[0-9]+(\.[0-9]+){0,2}$ ]]; then
        DRIVER_CUDA="0.0.0"
    fi

    IFS='.' read -r MAJ MIN PATCH <<EOF
${DRIVER_CUDA}
EOF
    # Missing components (e.g. "12.4" has no patch level) count as 0.
    : "${MIN:=0}"
    : "${PATCH:=0}"

    # `10#` forces base 10 so components with a leading zero (e.g. "08") are
    # not parsed as octal.
    DRIVER_INT=$((10#${MAJ} * 10000 + 10#${MIN} * 100 + 10#${PATCH}))
    TARGET_INT=$((12 * 10000 + 9 * 100 + 1))

    if [ "$DRIVER_INT" -lt "$TARGET_INT" ]; then
        # Only append the previous value when it is non-empty: a trailing ':'
        # would make the dynamic loader search the current directory.
        export LD_LIBRARY_PATH="/usr/local/cuda/compat${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
    fi
fi

exec text-embeddings-router "$@"

0 commit comments

Comments
 (0)