
Commit ff9e387

update dockerfile (#981)
1 parent 2cebd42 commit ff9e387

5 files changed: +72, -63 lines


docker/Dockerfile

Lines changed: 11 additions & 10 deletions
@@ -1,10 +1,8 @@
-FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
-ARG PYTORCH_VERSION=2.6.0
-ARG PYTHON_VERSION=3.9
-ARG CUDA_VERSION=12.4
-ARG MAMBA_VERSION=23.1.0-1
+ARG CUDA_VERSION=12.6.1
+FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
+ARG PYTHON_VERSION=3.10
+ARG MAMBA_VERSION=24.7.1-0
 ARG TARGETPLATFORM
-
 ENV PATH=/opt/conda/bin:$PATH \
     CONDA_PREFIX=/opt/conda

@@ -21,7 +19,7 @@ RUN case ${TARGETPLATFORM} in \
     "linux/arm64") MAMBA_ARCH=aarch64 ;; \
     *) MAMBA_ARCH=x86_64 ;; \
     esac && \
-    curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
+    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

@@ -36,11 +34,14 @@ RUN case ${TARGETPLATFORM} in \
 WORKDIR /root

 COPY ./requirements.txt /lightllm/requirements.txt
-RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124
+RUN pip install -U pip
+RUN pip install -r /lightllm/requirements.txt --no-cache-dir
+
+RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

-RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl
+RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .

-RUN pip install --no-cache-dir nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100
+RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel

 COPY . /lightllm
 RUN pip install -e /lightllm --no-cache-dir
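The rebuilt image now takes its CUDA-enabled PyTorch from requirements.txt, adds a nightly vllm wheel, the LightKernel package, and libnuma-dev for sgl_kernel. A minimal sanity-check script one might run inside the resulting container is sketched below; it only probes package names visible in this diff, and the script itself is illustrative rather than part of the repository.

import importlib.util

import torch

# Report the PyTorch build pulled in through requirements.txt (torch==2.7.1).
print("torch", torch.__version__, "built for CUDA", torch.version.cuda,
      "GPU visible:", torch.cuda.is_available())

# Probe the optional packages the Dockerfile now installs.
for mod in ("vllm", "sgl_kernel", "lightllm"):
    print(mod, "->", "found" if importlib.util.find_spec(mod) else "not installed")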

docker/Dockerfile.deepep

Lines changed: 31 additions & 26 deletions
@@ -1,10 +1,8 @@
-FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
-ARG PYTORCH_VERSION=2.6.0
-ARG PYTHON_VERSION=3.9
-ARG CUDA_VERSION=12.4
-ARG MAMBA_VERSION=23.1.0-1
+ARG CUDA_VERSION=12.6.1
+FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
+ARG PYTHON_VERSION=3.10
+ARG MAMBA_VERSION=24.7.1-0
 ARG TARGETPLATFORM
-
 ENV PATH=/opt/conda/bin:$PATH \
     CONDA_PREFIX=/opt/conda

@@ -21,7 +19,7 @@ RUN case ${TARGETPLATFORM} in \
     "linux/arm64") MAMBA_ARCH=aarch64 ;; \
     *) MAMBA_ARCH=x86_64 ;; \
     esac && \
-    curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
+    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

@@ -36,39 +34,46 @@ RUN case ${TARGETPLATFORM} in \
 WORKDIR /root

 COPY ./requirements.txt /lightllm/requirements.txt
-RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124
+RUN pip install -U pip
+RUN pip install -r /lightllm/requirements.txt --no-cache-dir

-RUN pip install --no-cache-dir https://github.com/ModelTC/flash-attn-3-build/releases/download/v2.7.4.post1/flash_attn-3.0.0b1-cp39-cp39-linux_x86_64.whl
+RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

-RUN pip install --no-cache-dir nvidia-nccl-cu12==2.25.1 # for allreduce hang issues in multinode H100
+RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .

-RUN git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git
-RUN cd DeepGEMM && python setup.py install
+RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
+RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev

-WORKDIR /root
-RUN git clone https://github.com/deepseek-ai/DeepEP.git
+ENV CUDA_HOME=/usr/local/cuda \
+    GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/

-# NVSHMEM
-RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
-RUN tar -xf nvshmem_src_3.2.5-1.txz \
-    && mv nvshmem_src nvshmem
+RUN mkdir -p /tmp/gdrcopy && cd /tmp \
+    && git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \
+    && cd gdrcopy/packages \
+    && CUDA=/usr/local/cuda ./build-deb-packages.sh \
+    && dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \
+    && cd / && rm -rf /tmp/gdrcopy

-WORKDIR /root/nvshmem
-RUN git apply /root/DeepEP/third-party/nvshmem.patch
+# Fix DeepEP IBGDA symlink
+RUN ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so

-WORKDIR /root/nvshmem
-ENV CUDA_HOME=/usr/local/cuda
-RUN NVSHMEM_SHMEM_SUPPORT=0 \
+RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \
+    && tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && mv nvshmem_src nvshmem \
+    && cd nvshmem \
+    && rm -f /root/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \
+    && NVSHMEM_SHMEM_SUPPORT=0 \
     NVSHMEM_UCX_SUPPORT=0 \
     NVSHMEM_USE_NCCL=0 \
     NVSHMEM_MPI_SUPPORT=0 \
     NVSHMEM_IBGDA_SUPPORT=1 \
     NVSHMEM_PMIX_SUPPORT=0 \
     NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
     NVSHMEM_USE_GDRCOPY=1 \
-    cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 -DMLX5_lib=/usr/lib/x86_64-linux-gnu/libmlx5.so.1 \
-    && cd build \
-    && make install -j64
+    cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 \
+    && cmake --build build --target install -j64
+
+ARG DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58
+RUN git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..

 WORKDIR /root/DeepEP
 ENV NVSHMEM_DIR=/root/nvshmem/install
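The DeepEP image additionally builds GDRCopy 2.4.4, NVSHMEM 3.3.9 with IBGDA and GDRCopy support, and a pinned DeepEP checkout against NVSHMEM_DIR=/root/nvshmem/install. A hedged probe of that environment is sketched below; the Python module name deep_ep is an assumption about the upstream DeepEP project and is not stated in this diff.

import importlib.util
import os

# NVSHMEM_DIR is set by the Dockerfile to /root/nvshmem/install.
nvshmem_dir = os.environ.get("NVSHMEM_DIR", "")
print("NVSHMEM_DIR:", nvshmem_dir or "<unset>")
print("NVSHMEM install tree present:",
      bool(nvshmem_dir) and os.path.isdir(os.path.join(nvshmem_dir, "lib")))

# deep_ep is assumed to be the package name produced by the DeepEP checkout.
print("deep_ep importable:", importlib.util.find_spec("deep_ep") is not None)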

lightllm/models/qwen2_vl/vision_process.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,6 @@
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    VideoInput,
     get_image_size,
     infer_channel_dimension_format,
     is_scaled_image,
@@ -54,6 +53,7 @@
     valid_images,
     validate_preprocess_arguments,
 )
+from transformers.video_utils import VideoInput
 from transformers.utils import TensorType, is_vision_available, logging

 logger = logging.get_logger(__name__)
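The only change in this file is the import location of VideoInput, which the pinned transformers release exposes from transformers.video_utils rather than bundling it with the image utilities. For code that must run against both layouts, a small compatibility shim along these lines is possible; the fallback branch is an assumption about older transformers releases, not something this commit adds.

try:
    # Location used by the transformers version this commit targets.
    from transformers.video_utils import VideoInput
except ImportError:
    # Older releases kept VideoInput next to the image utilities.
    from transformers.image_utils import VideoInput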

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 25 additions & 18 deletions
@@ -152,7 +152,7 @@ def _flash_attention_triton_fwd(

 _flash_attn_v3_available = False
 try:
-    from flash_attn_interface import _flash_attn_forward
+    from sgl_kernel.flash_attn import flash_attn_varlen_func

     _flash_attn_v3_available = True

@@ -166,36 +166,43 @@ def flash_attention_v3_fwd(
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
-    _flash_attn_forward(
+    window_size = (-1, -1)
+    torch.ops.sgl_kernel.fwd.default(
         q,
         k,
         v,
-        None,
-        None, # k_new, v_new
+        None, # k_new
+        None, # v_new
+        None, # qv
         o, # out
         cu_seqlens,
         cu_seqlens,
-        None, # cu_seqlens_q/k/k_new
-        None,
-        None, # seqused_q/k
-        max_seqlen,
-        max_seqlen, # max_seqlen_q/k
-        None,
+        None, # cu_seqlens_k_new
         None,
-        None, # page_table, kv_batch_idx, leftpad_k,
         None,
-        None, # rotary_cos/sin
+        max_seqlen,
+        max_seqlen,
+        None, # page_table,
+        None, # kv_batch_idx
+        None, # leftpad_k
+        None, # rotary cos
+        None, # rotary sin
+        None, # seqlens_rotary
         None,
         None,
         None,
         softmax_scale,
-        False, # causal
-        window_size=(-1, -1),
-        softcap=0.0,
+        False,
+        window_size[0],
+        window_size[1],
+        0.0,
+        is_rotary_interleaved=False,
+        scheduler_metadata=None,
         num_splits=1,
         pack_gqa=None,
         sm_margin=0,
     )
+
     return

 except ImportError:
@@ -205,10 +212,10 @@ def flash_attention_v3_fwd(

 def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     """
-    统一的 Flash Attention 接口。如果 _flash_attn_forward 存在,
-    则使用 flash_attention_v3_fwd,否则使用 Triton 版本。
+    统一的 Flash Attention 接口。如果 sgl_kernel 存在,
+    则使用 sgl_kernel里的接口,否则使用 Triton 版本。
     """
-    if _flash_attn_v3_available and is_hopper():
+    if _flash_attn_v3_available and is_hopper() and False:
         flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
         _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)
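The file keeps its probe-at-import-time structure: attempt to import the sgl_kernel flash-attention entry point, record the result in a flag, and route flash_attention_fwd to the Triton kernel when the import fails (the "and False" guard in this hunk forces the Triton path regardless). A self-contained sketch of that dispatch pattern follows; the backend bodies are placeholders and the helper names are illustrative, not the repository's.

_sgl_fa_available = False
try:
    from sgl_kernel.flash_attn import flash_attn_varlen_func  # noqa: F401

    _sgl_fa_available = True
except ImportError:
    pass


def _sgl_backend(q, k, v, o, cu_seqlens, max_seqlen):
    # Placeholder for the sgl_kernel path taken on Hopper GPUs.
    print("dispatching to sgl_kernel flash attention")


def _triton_backend(q, k, v, o, cu_seqlens, max_seqlen):
    # Placeholder for the Triton fallback kernel.
    print("dispatching to the Triton kernel")


def attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
    # Choose the backend from the import-time probe, mirroring flash_attention_fwd above.
    backend = _sgl_backend if _sgl_fa_available else _triton_backend
    backend(q, k, v, o, cu_seqlens, max_seqlen)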

requirements.txt

Lines changed: 4 additions & 8 deletions
@@ -34,7 +34,7 @@ multiprocessing-logging==0.3.4
 networkx==3.1
 ninja==1.11.1
 numpy==1.25.1
-packaging==23.1
+packaging==24.2
 pip==23.0.1
 pluggy==1.2.0
 plumbum==1.8.2
@@ -54,18 +54,15 @@ ruamel.yaml==0.17.32
 ruamel.yaml.clib==0.2.7
 s3transfer==0.6.1
 sentencepiece==0.2.0
-setuptools==65.6.3
+setuptools==77.0.3
 six==1.16.0
 sniffio==1.3.0
-sympy==1.13.1
 sortedcontainers==2.4.0
 toolz==0.12.0
-torch==2.6.0
-torchvision==0.21.0
+torch==2.7.1
 tqdm==4.65.0
 transformers==4.51.2
 tokenizers==0.21.1
-triton==3.2.0
 urllib3==1.26.16
 uvicorn==0.19.0
 uvloop==0.17.0
@@ -83,9 +80,8 @@ frozendict==2.4.6
 atomics==1.0.3
 easydict==1.13
 gunicorn==23.0.0
-vllm==0.8.5
 flashinfer-python==0.2.4
-sgl-kernel==0.1.4
+sgl-kernel==0.2.6
 httpx==0.28.1
 librosa==0.11.0
 cuda_bindings==12.9.0
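Several pins move in this commit (packaging, setuptools, torch, sgl-kernel) while sympy, torchvision, triton and vllm are dropped from the file, so a quick way to compare an environment against the new pins is to read installed versions through importlib.metadata. The script below is a hedged illustration, not part of the repository; its expected values are copied from the hunks above.

from importlib.metadata import PackageNotFoundError, version

# New pin values introduced by this commit.
expected = {
    "packaging": "24.2",
    "setuptools": "77.0.3",
    "torch": "2.7.1",
    "sgl-kernel": "0.2.6",
}

for name, want in expected.items():
    try:
        have = version(name)
    except PackageNotFoundError:
        have = "<not installed>"
    print(f"{name}: installed={have} expected={want} {'ok' if have == want else 'MISMATCH'}")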
