Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ ENV NV_CUDA_CUDART_DEV_VERSION=12.1.55-1 \
NV_NVML_DEV_VERSION=12.1.55-1 \
NV_LIBCUBLAS_DEV_VERSION=12.1.0.26-1 \
NV_LIBNPP_DEV_VERSION=12.0.2.50-1 \
NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1 \
NV_CUDNN9_CUDA_VERSION=9.6.0.74-1

RUN dnf config-manager \
--add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
Expand All @@ -103,6 +104,15 @@ RUN dnf config-manager \
libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
&& dnf clean all

# opening connection for too long in one go was resulting in timeouts
RUN dnf config-manager \
--add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
&& dnf clean packages \
&& dnf install -y \
libcusparselt0 libcusparselt-devel \
cudnn9-cuda-12-6-${NV_CUDNN9_CUDA_VERSION} \
&& dnf clean all

ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"

FROM cuda-devel AS python-installations
Expand Down Expand Up @@ -138,7 +148,8 @@ RUN if [[ -z "${WHEEL_VERSION}" ]]; \
RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \
python -m pip install --user wheel && \
python -m pip install --user "$(head bdist_name)" && \
python -m pip install --user "$(head bdist_name)[flash-attn]"
python -m pip install --user "$(head bdist_name)[flash-attn]" && \
python -m pip install --user "$(head bdist_name)[mamba]"

# fms_acceleration_peft = PEFT-training, e.g., 4bit QLoRA
# fms_acceleration_foak = Fused LoRA and triton kernels
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ aim = ["aim>=3.19.0,<4.0"]
mlflow = ["mlflow"]
fms-accel = ["fms-acceleration>=0.6"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
mamba = ["mamba_ssm[causal-conv1d] @ git+https://github.com/state-spaces/mamba.git"]
scanner-dev = ["HFResourceScanner>=0.1.0"]


Expand Down