diff --git a/docker/Dockerfile.ci.deps b/docker/Dockerfile.ci.deps new file mode 100644 index 000000000..f9683eea5 --- /dev/null +++ b/docker/Dockerfile.ci.deps @@ -0,0 +1,64 @@ +# TE CI Dockerfile +ARG BASE_DOCKER=rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1 +FROM $BASE_DOCKER +WORKDIR /var/lib/jenkins + +RUN apt --fix-broken install -y +RUN apt update \ + && apt install -y nano wget ninja-build \ + && apt install -y python3 python3-pip git \ + && apt install -y sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev + +RUN python3 -m pip install --upgrade pip +RUN pip install cmake setuptools wheel +RUN pip install ipython pytest fire pydantic pybind11 + +# RUN pip uninstall -y torch + +RUN apt install -y libzstd-dev +RUN apt install -y libibverbs-dev + +ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer +ENV PATH=$PATH:/opt/rocm/bin: +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/: + +# Install pytorch +# ARG PYTORCH_COMMIT="f929e0d602a71aa393ca2e6097674b210bdf321c" +ARG PYTORCH_ROCM_ARCH=gfx942 +RUN rm -fr pytorch \ + && git clone https://github.com/pytorch/pytorch \ + && cd pytorch \ + && git fetch origin ${PYTORCH_COMMIT} \ + && git checkout -q ${PYTORCH_COMMIT} \ + && git submodule update --recursive --init \ + && ./tools/amd_build/build_amd.py \ + && BUILD_TEST=0 python3 setup.py install \ + && cd .. + +# Install flash-attention +ENV GPU_ARCHS=${PYTORCH_ROCM_ARCH} +RUN git clone https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && git checkout v2.7.3-cktile \ + && pip install . \ + && cd .. + +# Install jax +ARG JAX_COMMIT="58e53c664a30015eac865d57b4987827460d67b0" +ARG XLA_COMMIT="fe4a1ec96238c765874ebc76f17184df0d2c7b1f" +RUN git clone https://github.com/ROCm/xla.git && cd xla && git fetch origin ${XLA_COMMIT} && git checkout -q ${XLA_COMMIT} && cd .. \ + && git clone https://github.com/ROCm/jax.git && cd jax && git fetch origin ${JAX_COMMIT} && git checkout -q ${JAX_COMMIT} \ + && python3 ./build/build.py --enable_rocm \ + --build_gpu_plugin \ + --use_clang=true \ + --clang_path=/opt/rocm-6.4.0/lib/llvm/bin/clang \ + --gpu_plugin_rocm_version=60 \ + --rocm_path=/opt/rocm-6.4.0/ \ + --rocm_amdgpu_targets=${GPU_ARCH} \ + --bazel_options=--override_repository=xla=/var/lib/jenkins/xla \ + && pip install jax==0.4.35 \ + && python3 setup.py develop --user && python3 -m pip install dist/*.whl \ + && pip install jax==0.4.35 + +WORKDIR /workspace/ +CMD ["/bin/bash"] \ No newline at end of file