Skip to content

Commit da5ac4a

Browse files
Merge remote-tracking branch 'upstream/main' into develop_IFU_20251118
# Conflicts: # .ci/docker/ci_commit_pins/triton.txt # requirements.txt
2 parents 3d74218 + e2b53ba commit da5ac4a

File tree

1,539 files changed

+52986
-16714
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,539 files changed

+52986
-16714
lines changed

.ci/docker/almalinux/Dockerfile

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
77
ENV LANG en_US.UTF-8
88
ENV LANGUAGE en_US.UTF-8
99

10-
ARG DEVTOOLSET_VERSION=11
10+
ARG DEVTOOLSET_VERSION=13
1111

1212
RUN yum -y update
1313
RUN yum -y install epel-release
1414
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
1515
RUN yum -y install glibc-langpack-en
16-
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
16+
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
1717
# Just add everything as a safe.directory for git since these will be used in multiple places with git
1818
RUN git config --global --add safe.directory '*'
1919
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
4141
# Install CUDA
4242
FROM base as cuda
4343
ARG CUDA_VERSION=12.6
44+
ARG DEVTOOLSET_VERSION=13
4445
RUN rm -rf /usr/local/cuda-*
4546
ADD ./common/install_cuda.sh install_cuda.sh
4647
COPY ./common/install_nccl.sh install_nccl.sh
@@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
5051
# Preserve CUDA_VERSION for the builds
5152
ENV CUDA_VERSION=${CUDA_VERSION}
5253
# Make things in our path by default
53-
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
54+
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
55+
5456

5557
FROM cuda as cuda12.6
5658
RUN bash ./install_cuda.sh 12.6
@@ -68,8 +70,22 @@ FROM cuda as cuda13.0
6870
RUN bash ./install_cuda.sh 13.0
6971
ENV DESIRED_CUDA=13.0
7072

71-
FROM ${ROCM_IMAGE} as rocm
73+
FROM ${ROCM_IMAGE} as rocm_base
74+
ARG DEVTOOLSET_VERSION=13
75+
ENV LC_ALL en_US.UTF-8
76+
ENV LANG en_US.UTF-8
77+
ENV LANGUAGE en_US.UTF-8
78+
# Install devtoolset on ROCm base image
79+
RUN yum -y update && \
80+
yum -y install epel-release && \
81+
yum -y install glibc-langpack-en && \
82+
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
83+
RUN git config --global --add safe.directory '*'
84+
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
85+
86+
FROM rocm_base as rocm
7287
ARG PYTORCH_ROCM_ARCH
88+
ARG DEVTOOLSET_VERSION=13
7389
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
7490
ADD ./common/install_mkl.sh install_mkl.sh
7591
RUN bash ./install_mkl.sh && rm install_mkl.sh
@@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
88104

89105
# Final step
90106
FROM ${BASE_TARGET} as final
107+
ARG DEVTOOLSET_VERSION=13
91108
COPY --from=openssl /opt/openssl /opt/openssl
92109
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
93110
COPY --from=conda /opt/conda /opt/conda

.ci/docker/almalinux/build.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in
3636
;;
3737
rocm*)
3838
BASE_TARGET=rocm
39-
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
40-
# add gfx950, gfx115x conditionally starting in ROCm 7.0
41-
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
42-
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
43-
fi
39+
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
4440
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
4541
;;
4642
*)
@@ -63,7 +59,7 @@ docker build \
6359
--target final \
6460
--progress plain \
6561
--build-arg "BASE_TARGET=${BASE_TARGET}" \
66-
--build-arg "DEVTOOLSET_VERSION=11" \
62+
--build-arg "DEVTOOLSET_VERSION=13" \
6763
${EXTRA_BUILD_ARGS} \
6864
-t ${tmp_tag} \
6965
$@ \

.ci/docker/build.sh

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ case "$tag" in
168168
VISION=yes
169169
TRITON=yes
170170
;;
171+
pytorch-linux-jammy-py3.11-clang12)
172+
ANACONDA_PYTHON_VERSION=3.11
173+
CLANG_VERSION=12
174+
VISION=no
175+
TRITON=no
176+
;;
177+
pytorch-linux-jammy-py3.12-clang12)
178+
ANACONDA_PYTHON_VERSION=3.12
179+
CLANG_VERSION=12
180+
VISION=no
181+
TRITON=no
182+
;;
171183
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
172184
if [[ $tag =~ "jammy" ]]; then
173185
ANACONDA_PYTHON_VERSION=3.10
@@ -176,7 +188,7 @@ case "$tag" in
176188
fi
177189
GCC_VERSION=11
178190
VISION=yes
179-
ROCM_VERSION=7.0
191+
ROCM_VERSION=7.1
180192
NINJA_VERSION=1.9.0
181193
TRITON=yes
182194
KATEX=yes
@@ -195,9 +207,9 @@ case "$tag" in
195207
NINJA_VERSION=1.9.0
196208
TRITON=yes
197209
;;
198-
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
210+
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
199211
ANACONDA_PYTHON_VERSION=3.10
200-
GCC_VERSION=11
212+
GCC_VERSION=13
201213
VISION=yes
202214
XPU_VERSION=2025.2
203215
NINJA_VERSION=1.9.0
@@ -248,6 +260,12 @@ case "$tag" in
248260
HALIDE=yes
249261
TRITON=yes
250262
;;
263+
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
264+
CUDA_VERSION=12.8.1
265+
ANACONDA_PYTHON_VERSION=3.12
266+
GCC_VERSION=11
267+
PALLAS=yes
268+
;;
251269
pytorch-linux-jammy-py3.12-triton-cpu)
252270
CUDA_VERSION=12.6
253271
ANACONDA_PYTHON_VERSION=3.12
@@ -261,19 +279,29 @@ case "$tag" in
261279
PYTHON_VERSION=3.10
262280
CUDA_VERSION=12.8.1
263281
;;
264-
pytorch-linux-jammy-aarch64-py3.10-gcc11)
282+
pytorch-linux-jammy-aarch64-py3.10-gcc13)
265283
ANACONDA_PYTHON_VERSION=3.10
266-
GCC_VERSION=11
284+
GCC_VERSION=13
267285
ACL=yes
268286
VISION=yes
269287
OPENBLAS=yes
270288
# snadampal: skipping llvm src build install because the current version
271289
# from pytorch/llvm:9.0.1 is x86 specific
272290
SKIP_LLVM_SRC_BUILD_INSTALL=yes
273291
;;
274-
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
292+
pytorch-linux-jammy-aarch64-py3.10-clang21)
275293
ANACONDA_PYTHON_VERSION=3.10
276-
GCC_VERSION=11
294+
CLANG_VERSION=21
295+
ACL=yes
296+
VISION=yes
297+
OPENBLAS=yes
298+
# snadampal: skipping llvm src build install because the current version
299+
# from pytorch/llvm:9.0.1 is x86 specific
300+
SKIP_LLVM_SRC_BUILD_INSTALL=yes
301+
;;
302+
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
303+
ANACONDA_PYTHON_VERSION=3.10
304+
GCC_VERSION=13
277305
ACL=yes
278306
VISION=yes
279307
OPENBLAS=yes
@@ -359,6 +387,7 @@ docker build \
359387
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
360388
--build-arg "EXECUTORCH=${EXECUTORCH}" \
361389
--build-arg "HALIDE=${HALIDE}" \
390+
--build-arg "PALLAS=${PALLAS}" \
362391
--build-arg "XPU_VERSION=${XPU_VERSION}" \
363392
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
364393
--build-arg "ACL=${ACL:-}" \

.ci/docker/ci_commit_pins/jax.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.8.0
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
<<<<<<< HEAD
12
ac80c4190aa0321f761a08af97e1e1eee41f01d9
3+
=======
4+
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
5+
>>>>>>> upstream/main

.ci/docker/common/install_clang.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
88
# work around ubuntu apt-get conflicts
99
sudo apt-get -y -f install
1010
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
11-
if [[ $CLANG_VERSION == 18 ]]; then
12-
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
11+
if [[ $CLANG_VERSION -ge 18 ]]; then
12+
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
1313
fi
1414
fi
1515

.ci/docker/common/install_gcc.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
77
# Need the official toolchain repo to get alternate packages
88
add-apt-repository ppa:ubuntu-toolchain-r/test
99
apt-get update
10-
apt-get install -y g++-$GCC_VERSION
10+
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
1111
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
1212
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
1313
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
14-
14+
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
1515

1616
# Cleanup package manager
1717
apt-get autoclean && apt-get clean

.ci/docker/common/install_jax.sh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
2+
3+
set -ex
4+
5+
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
6+
7+
# Get the pinned JAX version (same for all CUDA versions)
8+
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
9+
10+
function install_jax_12() {
11+
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
12+
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
13+
14+
# Verify installation
15+
python -c "import jax" # check for errors
16+
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
17+
}
18+
19+
function install_jax_13() {
20+
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
21+
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
22+
23+
# Verify installation
24+
python -c "import jax" # check for errors
25+
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
26+
}
27+
28+
# idiomatic parameter and option handling in sh
29+
while test $# -gt 0
30+
do
31+
case "$1" in
32+
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
33+
;;
34+
13.0|13.0.*) install_jax_13;
35+
;;
36+
*) echo "bad argument $1"; exit 1
37+
;;
38+
esac
39+
shift
40+
done
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/bin/bash
2+
# Script used only in CD pipeline
3+
4+
set -ex
5+
6+
# install dependencies
7+
dnf -y install gmp-devel libmpc-devel texinfo flex bison
8+
9+
cd /usr/local/src
10+
# fetch source for gcc 13
11+
git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
12+
13+
mkdir -p gcc-13.3.0/build-gomp
14+
cd gcc-13.3.0/build-gomp
15+
16+
# configure gcc build
17+
# I got these flags by:
18+
# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
19+
# dnf install -y dnf-plugins-core rpmdevtools
20+
# dnf download --source libgomp
21+
# 2. extracting the gcc.spec from the source.
22+
# rpmdev-extract gcc-xx.src.rpm
23+
# 3. extracting optflags and ld_flags from gcc.spec:
24+
# rpm --eval '%{optflags}'
25+
# rpm --eval '%{build_ldflags}'
26+
#
27+
# I had to remove the following flags because they didn't compile for this version of libgomp:
28+
# -Werror=format-security
29+
# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
30+
# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
31+
#
32+
# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
33+
34+
OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
35+
' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
36+
' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
37+
' -fstack-protector-strong -fasynchronous-unwind-tables'\
38+
' -fstack-clash-protection'
39+
40+
LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
41+
42+
CFLAGS="$OPT_FLAGS" \
43+
CXXFLAGS="$OPT_FLAGS" \
44+
LDFLAGS="$LDFLAGS" \
45+
../configure \
46+
--prefix=/usr \
47+
--libdir=/usr/lib64 \
48+
--enable-languages=c,c++ \
49+
--disable-multilib \
50+
--disable-bootstrap \
51+
--enable-libgomp
52+
53+
# only build libgomp
54+
make -j$(nproc) all-target-libgomp
55+
56+
make install-target-libgomp

.ci/docker/common/install_openblas.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
1010

1111
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
1212
OPENBLAS_BUILD_FLAGS="
13+
CC=gcc
1314
NUM_THREADS=128
1415
USE_OPENMP=1
1516
NO_SHARED=0

0 commit comments

Comments
 (0)