diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index fcbddb79c35e..2795548fb339 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -18,6 +18,17 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic ARG COMMON_WORKDIR WORKDIR ${COMMON_WORKDIR} +FROM base AS build_fa +ARG FA_BRANCH="0e60e394" +ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN git clone ${FA_REPO} +RUN cd flash-attention \ + && git checkout ${FA_BRANCH} \ + && git submodule update --init \ + && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install # ----------------------- # vLLM fetch stages