Skip to content

Commit 54816c1

Browse files
committed
build(torch): Restore arch filtering logic for CUDA versions below 12.8
1 parent b454bd4 commit 54816c1

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

torch/Dockerfile

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,13 @@ RUN --mount=type=bind,from=triton-downloader,source=/git/triton,target=triton/,r
360360

361361
ARG BUILD_TORCH_VERSION
362362
ENV TORCH_VERSION=$BUILD_TORCH_VERSION
363-
# Filter out the 10.0 & 12.0 arches on CUDA versions != 12.8 and != 12.9
364-
# Note: commented out for testing CUDA 13 builds
365-
#ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
366-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
367-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
368-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
369-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
363+
# Filter out the 10.0 & 12.0 arches on CUDA versions < 12.8
364+
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
365+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
366+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#13.?.?}"
367+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
368+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
369+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
370370

371371
RUN printf 'Arch: %s\nTORCH_CUDA_ARCH_LIST=%s\n' "$(uname -m)" "${TORCH_CUDA_ARCH_LIST}"
372372

@@ -694,13 +694,13 @@ ARG BUILD_TORCH_AUDIO_VERSION
694694
ENV TORCH_VERSION=$BUILD_TORCH_VERSION
695695
ENV TORCH_VISION_VERSION=$BUILD_TORCH_VISION_VERSION
696696
ENV TORCH_AUDIO_VERSION=$BUILD_TORCH_AUDIO_VERSION
697-
# Filter out the 10.0 & 12.0 arches on CUDA versions != 12.8 and != 12.9
698-
# Note: commented out for testing CUDA 13 builds
699-
#ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
700-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
701-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
702-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
703-
#ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
697+
# Filter out the 10.0 & 12.0 arches on CUDA versions < 12.8
698+
ENV TORCH_CUDA_ARCH_LIST="${CUDA_VERSION##12.8.*}||${TORCH_CUDA_ARCH_LIST/ 10.0 12.0/}||${TORCH_CUDA_ARCH_LIST}"
699+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#12.9.?}"
700+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#13.?.?}"
701+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#||*||}"
702+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST%||*}"
703+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST#*||}"
704704

705705
COPY --link --chmod=755 install_cudnn.sh /tmp/install_cudnn.sh
706706
# - libnvjitlink-X-Y only exists for CUDA versions >= 12-0.

0 commit comments

Comments
 (0)