diff --git a/Dockerfile.ubi b/Dockerfile.ubi
index eaa046c371a7..f50637dc0392 100644
--- a/Dockerfile.ubi
+++ b/Dockerfile.ubi
@@ -5,6 +5,7 @@ ARG PYTHON_VERSION=3.12
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
 ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
 
+
 ## Base Layer ##################################################################
 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} as base
 ARG PYTHON_VERSION
@@ -50,12 +51,22 @@ ENV CUDA_HOME="/usr/local/cuda" \
     PATH="${CUDA_HOME}/bin:${PATH}" \
     LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}"
 
+
 ## Python cuda base #################################################################
 FROM cuda-base AS python-cuda-base
 
 ENV VIRTUAL_ENV=/opt/vllm
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 
+# install numactl and common dependencies for fastsafetensors
+RUN microdnf install autoconf automake libtool make rpm-build -y && \
+    microdnf download --source numactl.src && \
+    NUMACTL_V=$(rpm -qp --qf "%{VERSION}-%{RELEASE}\n" numactl-*.rpm | sort -V | tail -n 1) && \
+    rpm -i numactl-${NUMACTL_V}.src.rpm && \
+    rpmbuild -ba /root/rpmbuild/SPECS/numactl.spec && \
+    rpm -i /root/rpmbuild/RPMS/x86_64/{numactl-libs-${NUMACTL_V}.x86_64.rpm,numactl-${NUMACTL_V}.x86_64.rpm,numactl-devel-${NUMACTL_V}.x86_64.rpm} && \
+    microdnf clean all
+
 # install cuda and common dependencies
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
@@ -80,6 +91,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
         -r requirements-cuda.txt \
         -r requirements-dev.txt
 
+
 ## Builder #####################################################################
 FROM dev AS build
 
@@ -122,6 +134,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
     CMAKE_BUILD_TYPE=Release \
     python3 setup.py bdist_wheel --dist-dir=dist
 
+
 #################### libsodium Build IMAGE ####################
 FROM base as libsodium-builder
 
@@ -139,6 +152,7 @@ RUN curl -LO https://github.com/jedisct1/libsodium/releases/download/${LIBSODIUM
 RUN CFLAGS="-O3 -Wall -Werror=format-security -Wno-unused-function -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection"\
     ./configure --prefix="/usr/" && make -j $MAX_JOBS && make check
 
+
 ## Release #####################################################################
 FROM python-install AS vllm-openai
 ARG PYTHON_VERSION
@@ -152,6 +166,7 @@ ENV PATH=$VIRTUAL_ENV/bin:$PATH
 ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_nvrtc/lib:${LD_LIBRARY_PATH}"
 ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cuda_runtime/lib:${LD_LIBRARY_PATH}"
 ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvtx/lib:${LD_LIBRARY_PATH}"
+ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/cufile/lib:${LD_LIBRARY_PATH}"
 
 # Triton needs a CC compiler
 RUN microdnf install -y gcc \
@@ -202,14 +217,21 @@ WORKDIR /home/vllm
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 
+## Last image ##################################################################
 FROM vllm-openai as vllm-grpc-adapter
 
 USER root
 
+# Installing numactl and numactl-libs for fastsafetensors
+RUN --mount=type=bind,from=python-cuda-base,src=/root/rpmbuild/RPMS/x86_64,target=/workspace/RPMS \
+    NUMACTL_V=$(rpm -qp --qf "%{VERSION}-%{RELEASE}\n" /workspace/RPMS/numactl-*.rpm | sort -V | tail -n 1) && \
+    rpm -i /workspace/RPMS/numactl-${NUMACTL_V}.x86_64.rpm /workspace/RPMS/numactl-libs-${NUMACTL_V}.x86_64.rpm
+
+# Ensure correct vLLM version, vllm-tgis-adapter and cufile for fastsafetensors
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \
-    HOME=/root uv pip install "$(echo /workspace/dist/*.whl)[tensorizer]" vllm-tgis-adapter==0.6.0
+    HOME=/root uv pip install "$(echo /workspace/dist/*.whl)[tensorizer]" vllm-tgis-adapter==0.6.0 nvidia-cufile-cu12
 
 ENV GRPC_PORT=8033 \
     PORT=8000 \
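The Dockerfile changes above provide two runtime prerequisites: libnuma (rebuilt from the numactl source RPM and installed into the final gRPC image) and libcufile (from the `nvidia-cufile-cu12` wheel, exposed via the new `LD_LIBRARY_PATH` entry). The following is a hypothetical smoke test, not part of this patch; the sonames `libnuma.so.1` and `libcufile.so.0` are the standard ones shipped by those packages.

```python
# Hypothetical smoke test (not part of this PR): verify that libnuma
# and libcufile can be loaded inside the built image before enabling
# fastsafetensors.
import ctypes

for soname in ("libnuma.so.1", "libcufile.so.0"):
    try:
        ctypes.CDLL(soname)  # raises OSError if the library is missing/broken
        print(f"{soname}: OK")
    except OSError as exc:
        print(f"{soname}: NOT loadable ({exc})")
```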
diff --git a/docs/source/serving/weights_loading_with_fastsafetensor.rst b/docs/source/serving/weights_loading_with_fastsafetensor.rst
new file mode 100644
index 000000000000..2678ae38a15e
--- /dev/null
+++ b/docs/source/serving/weights_loading_with_fastsafetensor.rst
@@ -0,0 +1,5 @@
+Loading Model Weights with fastsafetensors
+==========================================
+
+The fastsafetensors library enables loading model weights to GPU memory by leveraging GPU Direct Storage. See https://github.com/foundation-model-stack/fastsafetensors for more details.
+To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``.
\ No newline at end of file
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 8002fbd8ee5b..f793b822423f 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -8,3 +8,4 @@ torch == 2.5.1
 # These must be updated alongside torch
 torchvision == 0.20.1   # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch 2.5.1
+fastsafetensors  # Required for model loading via GPU Direct Storage
\ No newline at end of file
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index f2d9293b31a8..791ee01cd204 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -42,9 +42,10 @@
     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
-    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
-    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
+    gguf_quant_weights_iterator, initialize_dummy_weights,
+    np_cache_weights_iterator, pt_weights_iterator,
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -307,7 +308,15 @@ def _get_weights_iterator(
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            use_fastsafe_tensor = os.getenv('USE_FASTSAFETENSOR',
+                                            'False').lower() == 'true'
+            if use_fastsafe_tensor:
+                logger.info("Using fastsafetensor for loading weights")
+                weights_iterator = fastsafetensors_weights_iterator(
+                    hf_weights_files)
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files)
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
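The loader change gates the new code path behind the `USE_FASTSAFETENSOR` environment variable, and only when the checkpoint is in safetensors format. A minimal usage sketch, assuming this patch is applied; `facebook/opt-125m` is just a placeholder model.

```python
# Usage sketch, assuming this patch is applied. The flag must be set
# before weights are loaded; it is only consulted for safetensors
# checkpoints ("facebook/opt-125m" is a placeholder model).
import os

os.environ["USE_FASTSAFETENSOR"] = "true"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")
print(llm.generate("Hello, world!")[0].outputs[0].text)
```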
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 8aa0c98df70d..7002df2d99a2 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -14,6 +14,7 @@
 import huggingface_hub.constants
 import numpy as np
 import torch
+from fastsafetensors import SafeTensorsFileLoader, SingleGroup
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
@@ -418,6 +419,37 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def fastsafetensors_weights_iterator(
+    hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files
+    using the fastsafetensors library."""
+    pg = SingleGroup()
+    if torch.distributed.is_initialized():
+        pg = torch.distributed.group.WORLD
+
+    device = torch.device(f'cuda:{pg.rank()}')
+    weight_files_sub_lists = [
+        hf_weights_files[i:i + pg.size()]
+        for i in range(0, len(hf_weights_files), pg.size())
+    ]
+
+    for f_list in weight_files_sub_lists:
+        # nogds=True disables NVIDIA GDS support in fastsafetensors
+        loader = SafeTensorsFileLoader(pg, device,
+                                       nogds=True,
+                                       debug_log=False)
+        rank_file_map = {i: [f] for i, f in enumerate(f_list)}
+        loader.add_filenames(rank_file_map)
+        fb = loader.copy_files_to_device()
+        keys = list(fb.key_to_rank_lidx.keys())
+        for k in keys:
+            t = fb.get_tensor(k)
+            yield k, t
+        fb.close()
+        loader.close()
+
+
 def runai_safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
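For reference, the new iterator can be exercised on its own. Below is a minimal single-process sketch grounded in the same calls the function above uses; `model.safetensors` is a placeholder path.

```python
# Single-process sketch of the fastsafetensors API used by
# fastsafetensors_weights_iterator above. "model.safetensors" is a
# placeholder path; nogds=True keeps the bounce-buffer path so no
# NVIDIA GDS driver is required.
import torch
from fastsafetensors import SafeTensorsFileLoader, SingleGroup

pg = SingleGroup()  # no torch.distributed initialization needed
device = torch.device(f"cuda:{pg.rank()}")

loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=False)
loader.add_filenames({0: ["model.safetensors"]})  # rank -> file list
fb = loader.copy_files_to_device()
try:
    for name in fb.key_to_rank_lidx:
        tensor = fb.get_tensor(name)
        print(name, tuple(tensor.shape), tensor.dtype)
finally:
    fb.close()
    loader.close()
```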