11ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.4-complete
22ARG HIPBLASLT_BRANCH="0f5d6c6d"
3+ ARG HIPBLAS_COMMON_BRANCH="9b80ba8e"
34ARG LEGACY_HIPBLASLT_OPTION=
4- ARG RCCL_BRANCH="648a58d"
5- ARG RCCL_REPO="https://github.com/ROCm/rccl"
6- ARG TRITON_BRANCH="5fe38ffd"
5+ ARG TRITON_BRANCH="981e987e"
76ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
8- ARG PYTORCH_BRANCH="13417947 "
9- ARG PYTORCH_VISION_BRANCH="v0.22.0-rc5 "
7+ ARG PYTORCH_BRANCH="295f2ed4 "
8+ ARG PYTORCH_VISION_BRANCH="v0.21.0 "
109ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
1110ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
1211ARG FA_BRANCH="1a7f4dfa"
1312ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
14- ARG AITER_BRANCH="c4a9ce75 "
13+ ARG AITER_BRANCH="5a77249 "
1514ARG AITER_REPO="https://github.com/ROCm/aiter.git"
1615
1716FROM ${BASE_IMAGE} AS base
@@ -45,25 +44,26 @@ RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython
4544
4645FROM base AS build_hipblaslt
4746ARG HIPBLASLT_BRANCH
47+ ARG HIPBLAS_COMMON_BRANCH
4848# Set to "--legacy_hipblas_direct" for ROCm<=6.2
4949ARG LEGACY_HIPBLASLT_OPTION
50+ RUN git clone https://github.com/ROCm/hipBLAS-common.git
51+ RUN apt-get remove -y hipblaslt && apt-get autoremove -y && apt-get autoclean -y
52+ RUN cd hipBLAS-common \
53+ && git checkout ${HIPBLAS_COMMON_BRANCH} \
54+ && mkdir build \
55+ && cd build \
56+ && cmake .. \
57+ && make package \
58+ && dpkg -i ./*.deb
5059RUN git clone https://github.com/ROCm/hipBLASLt
5160RUN cd hipBLASLt \
5261 && git checkout ${HIPBLASLT_BRANCH} \
5362 && apt-get install -y llvm-dev \
5463 && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
5564 && cd build/release \
5665 && make package
57- RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/install
58-
59- FROM base AS build_rccl
60- ARG RCCL_BRANCH
61- ARG RCCL_REPO
62- RUN git clone ${RCCL_REPO}
63- RUN cd rccl \
64- && git checkout ${RCCL_BRANCH} \
65- && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
66- RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
66+ RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
6767
6868FROM base AS build_triton
6969ARG TRITON_BRANCH
@@ -119,15 +119,25 @@ RUN cd aiter \
119119RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
120120RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
121121
122+ FROM base AS debs
123+ RUN mkdir /app/debs
124+ RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
125+ cp /install/*.deb /app/debs
126+ RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
127+ cp /install/*.whl /app/debs
128+ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
129+ cp /install/*.whl /app/debs
130+ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
131+ cp /install/*.whl /app/debs
132+ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
133+ cp /install/*.whl /app/debs
134+
122135FROM base AS final
123136RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
124137 dpkg -i /install/*deb \
125- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
126- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
127- RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
128- dpkg -i /install/*deb \
129- && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
130- && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
138+ && perl -p -i -e 's/, hipblas-common-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
139+ && perl -p -i -e 's/, hipblaslt-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
140+ && perl -p -i -e 's/, hipblaslt \([^)]*?\), /, /g' /var/lib/dpkg/status
131141RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
132142 pip install /install/*.whl
133143RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
@@ -141,8 +151,6 @@ ARG BASE_IMAGE
141151ARG HIPBLAS_COMMON_BRANCH
142152ARG HIPBLASLT_BRANCH
143153ARG LEGACY_HIPBLASLT_OPTION
144- ARG RCCL_BRANCH
145- ARG RCCL_REPO
146154ARG TRITON_BRANCH
147155ARG TRITON_REPO
148156ARG PYTORCH_BRANCH
@@ -154,10 +162,9 @@ ARG FA_REPO
154162ARG AITER_BRANCH
155163ARG AITER_REPO
156164RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
165+ && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
157166 && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
158167 && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
159- && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
160- && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
161168 && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
162169 && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
163170 && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
0 commit comments