@@ -3,12 +3,13 @@ ARG BASE_UBI_IMAGE_TAG=9.3-1552
 ARG PROTOC_VERSION=25.2
 ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
 # ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
+ARG AUTO_GPTQ_VERSION=0.7.1
 
 # match PyTorch version that was used to compile flash-attention v2 pre-built wheels
 # e.g. flash-attn v2.5.2 => torch ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126']
 # https://github.com/Dao-AILab/flash-attention/blob/v2.5.2/.github/workflows/publish.yml#L47
 # use nightly build index for torch .dev pre-release versions
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.2.1
 
 ARG PYTHON_VERSION=3.11
 
@@ -35,18 +36,19 @@ ENV LANG=C.UTF-8 \
 ## CUDA Base ###################################################################
 FROM base as cuda-base
 
-ENV CUDA_VERSION=11.8.0 \
-    NV_CUDA_LIB_VERSION=11.8.0-1 \
+# Ref: https://docs.nvidia.com/cuda/archive/12.1.0/cuda-toolkit-release-notes/
+ENV CUDA_VERSION=12.1.0 \
+    NV_CUDA_LIB_VERSION=12.1.0-1 \
     NVIDIA_VISIBLE_DEVICES=all \
     NVIDIA_DRIVER_CAPABILITIES=compute,utility \
-    NV_CUDA_CUDART_VERSION=11.8.89-1 \
-    NV_CUDA_COMPAT_VERSION=520.61.05-1
+    NV_CUDA_CUDART_VERSION=12.1.55-1 \
+    NV_CUDA_COMPAT_VERSION=530.30.02-1
 
 RUN dnf config-manager \
     --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
     && dnf install -y \
-        cuda-cudart-11-8-${NV_CUDA_CUDART_VERSION} \
-        cuda-compat-11-8-${NV_CUDA_COMPAT_VERSION} \
+        cuda-cudart-12-1-${NV_CUDA_CUDART_VERSION} \
+        cuda-compat-12-1-${NV_CUDA_COMPAT_VERSION} \
     && echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf \
     && echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf \
     && dnf clean all
@@ -59,22 +61,23 @@ ENV CUDA_HOME="/usr/local/cuda" \
 ## CUDA Development ############################################################
 FROM cuda-base as cuda-devel
 
-ENV NV_CUDA_CUDART_DEV_VERSION=11.8.89-1 \
-    NV_NVML_DEV_VERSION=11.8.86-1 \
-    NV_LIBCUBLAS_DEV_VERSION=11.11.3.6-1 \
-    NV_LIBNPP_DEV_VERSION=11.8.0.86-1 \
-    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.15.5-1+cuda11.8
+# Ref: https://developer.nvidia.com/nccl/nccl-legacy-downloads
+ENV NV_CUDA_CUDART_DEV_VERSION=12.1.55-1 \
+    NV_NVML_DEV_VERSION=12.1.55-1 \
+    NV_LIBCUBLAS_DEV_VERSION=12.1.0.26-1 \
+    NV_LIBNPP_DEV_VERSION=12.0.2.50-1 \
+    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1
 
 RUN dnf config-manager \
     --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
     && dnf install -y \
-        cuda-command-line-tools-11-8-${NV_CUDA_LIB_VERSION} \
-        cuda-libraries-devel-11-8-${NV_CUDA_LIB_VERSION} \
-        cuda-minimal-build-11-8-${NV_CUDA_LIB_VERSION} \
-        cuda-cudart-devel-11-8-${NV_CUDA_CUDART_DEV_VERSION} \
-        cuda-nvml-devel-11-8-${NV_NVML_DEV_VERSION} \
-        libcublas-devel-11-8-${NV_LIBCUBLAS_DEV_VERSION} \
-        libnpp-devel-11-8-${NV_LIBNPP_DEV_VERSION} \
+        cuda-command-line-tools-12-1-${NV_CUDA_LIB_VERSION} \
+        cuda-libraries-devel-12-1-${NV_CUDA_LIB_VERSION} \
+        cuda-minimal-build-12-1-${NV_CUDA_LIB_VERSION} \
+        cuda-cudart-devel-12-1-${NV_CUDA_CUDART_DEV_VERSION} \
+        cuda-nvml-devel-12-1-${NV_NVML_DEV_VERSION} \
+        libcublas-devel-12-1-${NV_LIBCUBLAS_DEV_VERSION} \
+        libnpp-devel-12-1-${NV_LIBNPP_DEV_VERSION} \
         libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
     && dnf clean all
 
@@ -199,12 +202,12 @@ ENV PATH=/opt/tgis/bin/:$PATH
 # Install specific version of torch
 RUN pip install ninja==1.11.1.1 --no-cache-dir
 RUN pip install packaging --no-cache-dir
-RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu118" --no-cache-dir
+RUN pip install torch==$PYTORCH_VERSION+cu121 --index-url "${PYTORCH_INDEX}/cu121" --no-cache-dir
 
 
 ## Build flash attention v2 ####################################################
 FROM python-builder as flash-att-v2-builder
-ARG FLASH_ATT_VERSION=v2.5.2
+ARG FLASH_ATT_VERSION=v2.5.6
 
 WORKDIR /usr/src/flash-attention-v2
 
@@ -217,14 +220,15 @@ RUN MAX_JOBS=2 pip --verbose wheel --no-deps flash-attn==${FLASH_ATT_VERSION} \
 
 
 ## Install auto-gptq ###########################################################
-FROM python-builder as auto-gptq-installer
-ARG AUTO_GPTQ_REF=ccb6386ebfde63c17c45807d38779a93cd25846f
-
-WORKDIR /usr/src/auto-gptq-wheel
-
-# numpy is required to run auto-gptq's setup.py
-RUN pip install numpy
-RUN DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
+## Uncomment if a custom autogptq build is required
+# FROM python-builder as auto-gptq-installer
+# ARG AUTO_GPTQ_REF=896d8204bc89a7cfbda42bf3314e13cf4ce20b02
+#
+# WORKDIR /usr/src/auto-gptq-wheel
+#
+# # numpy is required to run auto-gptq's setup.py
+# RUN pip install numpy
+# RUN DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
 
 ## Build libraries #############################################################
 FROM python-builder as build
@@ -241,18 +245,20 @@ FROM base as flash-att-v2-cache
 COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2 /usr/src/flash-attention-v2
 
 
-## Auto gptq cached build image
-FROM base as auto-gptq-cache
-
-# Copy just the wheel we built for auto-gptq
-COPY --from=auto-gptq-installer /usr/src/auto-gptq-wheel /usr/src/auto-gptq-wheel
+## Auto gptq cached build image ################################################
+## Uncomment if a custom autogptq build is required
+# FROM base as auto-gptq-cache
+#
+# # Copy just the wheel we built for auto-gptq
+# COPY --from=auto-gptq-installer /usr/src/auto-gptq-wheel /usr/src/auto-gptq-wheel
 
 
 ## Full set of python installations for server release #########################
 
 FROM python-builder as python-installations
 
 ARG PYTHON_VERSION
+ARG AUTO_GPTQ_VERSION
 ARG SITE_PACKAGES=/opt/tgis/lib/python${PYTHON_VERSION}/site-packages
 
 COPY --from=build /opt/tgis /opt/tgis
@@ -265,15 +271,21 @@ RUN --mount=type=bind,from=flash-att-v2-cache,src=/usr/src/flash-attention-v2,ta
     pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 
 # Copy over the auto-gptq wheel and install it
-RUN --mount=type=bind,from=auto-gptq-cache,src=/usr/src/auto-gptq-wheel,target=/usr/src/auto-gptq-wheel \
-    pip install /usr/src/auto-gptq-wheel/*.whl --no-cache-dir
+# RUN --mount=type=bind,from=auto-gptq-cache,src=/usr/src/auto-gptq-wheel,target=/usr/src/auto-gptq-wheel \
+#     pip install /usr/src/auto-gptq-wheel/*.whl --no-cache-dir
+
+# We only need to install a custom-built auto-gptq version if we need a pre-release
+# or are using a PyTorch nightly version
+RUN pip install auto-gptq=="${AUTO_GPTQ_VERSION}" --no-cache-dir
 
 # Install server
 # git is required to pull the fms-extras dependency
 RUN dnf install -y git && dnf clean all
 COPY proto proto
 COPY server server
-RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir
+# Extra url is required to install cuda-12 version of onnxruntime-gpu
+# Ref: https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x
+RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir --extra-index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
 
 # Patch codegen model changes into transformers 4.35
 RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py
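
For reference, a minimal sketch of a build invocation that exercises the bumped defaults. The `--build-arg` values just restate the new defaults introduced by this diff; the image tag is an illustrative assumption, and `python-installations` is one of the stages defined above, used here only as an example target:

    # Overrides are optional: the values shown are the defaults this change sets.
    # The tag name is hypothetical, not part of this commit.
    docker build . \
        --target python-installations \
        --build-arg PYTORCH_VERSION=2.2.1 \
        --build-arg AUTO_GPTQ_VERSION=0.7.1 \
        -t tgis-python-installations:cuda12.1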