Skip to content

Commit a6fab82

Browse files
Skylion007 authored and pytorchmergebot committed
[BE]: Fix NVSHMEM builds, add missing 12.9 dependency and update to latest for 2.8RC (pytorch#157453)
Fixes our broken NVSHMEM builds (we were not building or testing them before) and also updates to the latest version. The newest version has critical support for things that actually make it useful, such as bfloat16 and float16. This is a proper fix for: pytorch#157411 Pull Request resolved: pytorch#157453 Approved by: https://github.com/kwen2501, https://github.com/atalman
1 parent dd3e717 commit a6fab82

File tree

5 files changed

+78
-32
lines changed

5 files changed

+78
-32
lines changed

.ci/docker/common/install_cuda.sh

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ else
1010
arch_path='sbsa'
1111
fi
1212

13+
NVSHMEM_VERSION=3.3.9
14+
1315
function install_cuda {
1416
version=$1
1517
runfile=$2
@@ -40,13 +42,52 @@ function install_cudnn {
4042
rm -rf tmp_cudnn
4143
}
4244

45+
#######################################
# Download the NVSHMEM redistributable archive from NVIDIA's download server
# and copy its headers and libraries into /usr/local/include and
# /usr/local/lib.
# Globals:
#   arch_path (read) - architecture name used in the archive file name;
#                      "sbsa" and "x86_64" are mapped to the download
#                      directories "aarch64" and "x64" respectively.
#   arch      (read, optional) - download directory used for any other
#                      arch_path value; when unset or empty, arch_path
#                      itself is used instead.
# Arguments:
#   $1 - CUDA major version, e.g. "12"
#   $2 - NVSHMEM version, e.g. "3.3.9"
# Outputs:
#   A confirmation line on stdout on success; diagnostics on stderr.
# Returns:
#   0 on success, 2 on missing arguments, non-zero if any step fails.
#######################################
function install_nvshmem {
    local cuda_major_version=${1:-}      # e.g. "12"
    local nvshmem_version=${2:-}         # e.g. "3.3.9"
    local dl_arch tmpdir filename url
    local status=0

    if [ -z "${cuda_major_version}" ] || [ -z "${nvshmem_version}" ]; then
        echo "usage: install_nvshmem <cuda_major_version> <nvshmem_version>" >&2
        return 2
    fi

    case "${arch_path:-}" in
        sbsa)
            dl_arch="aarch64"
            ;;
        x86_64)
            dl_arch="x64"
            ;;
        *)
            # The original fallback read an "arch" variable that is not
            # defined in this function; honour it if a caller set it,
            # otherwise fall back to arch_path so the URL is never left
            # with an empty path segment.
            dl_arch="${arch:-${arch_path:-}}"
            ;;
    esac

    if [ -z "${dl_arch}" ]; then
        echo "install_nvshmem: cannot determine download architecture (arch_path is empty)" >&2
        return 1
    fi

    tmpdir="tmp_nvshmem"

    # nvSHMEM license: https://docs.nvidia.com/nvshmem/api/sla.html
    filename="libnvshmem_cuda${cuda_major_version}-linux-${arch_path}-${nvshmem_version}"
    url="https://developer.download.nvidia.com/compute/redist/nvshmem/${nvshmem_version}/builds/cuda${cuda_major_version}/txz/agnostic/${dl_arch}/${filename}.tar.gz"

    # download, unpack, install
    # Run in a subshell so the caller's working directory is never changed,
    # and check every step explicitly: a failing "mkdir ... && cd ..." list
    # is not caught by "set -e", which would let the remaining commands run
    # in the wrong directory.
    (
        mkdir -p "${tmpdir}" || exit 1
        cd "${tmpdir}" || exit 1
        wget -q "${url}" || exit 1
        tar xf "${filename}.tar.gz" || exit 1
        cp -a "libnvshmem/include/"* /usr/local/include/ || exit 1
        cp -a "libnvshmem/lib/"* /usr/local/lib/ || exit 1
    ) || status=$?

    # cleanup (also on failure, so a partial download is not left behind)
    rm -rf "${tmpdir}"

    if [ "${status}" -ne 0 ]; then
        echo "install_nvshmem: failed to install nvSHMEM ${nvshmem_version} from ${url}" >&2
        return "${status}"
    fi

    echo "nvSHMEM ${nvshmem_version} for CUDA ${cuda_major_version} (${arch_path}) installed."
}
80+
81+
4382
function install_126 {
4483
CUDNN_VERSION=9.10.2.21
45-
echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1"
84+
echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
4685
install_cuda 12.6.3 cuda_12.6.3_560.35.05_linux
4786

4887
install_cudnn 12 $CUDNN_VERSION
4988

89+
install_nvshmem 12 $NVSHMEM_VERSION
90+
5091
CUDA_VERSION=12.6 bash install_nccl.sh
5192

5293
CUDA_VERSION=12.6 bash install_cusparselt.sh
@@ -56,13 +97,15 @@ function install_126 {
5697

5798
function install_129 {
5899
CUDNN_VERSION=9.10.2.21
59-
echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1"
100+
echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
60101
# install CUDA 12.9.1 in the same container
61102
install_cuda 12.9.1 cuda_12.9.1_575.57.08_linux
62103

63104
# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
64105
install_cudnn 12 $CUDNN_VERSION
65106

107+
install_nvshmem 12 $NVSHMEM_VERSION
108+
66109
CUDA_VERSION=12.9 bash install_nccl.sh
67110

68111
CUDA_VERSION=12.9 bash install_cusparselt.sh
@@ -106,13 +149,15 @@ function prune_126 {
106149

107150
function install_128 {
108151
CUDNN_VERSION=9.8.0.87
109-
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1"
152+
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
110153
# install CUDA 12.8.1 in the same container
111154
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
112155

113156
# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
114157
install_cudnn 12 $CUDNN_VERSION
115158

159+
install_nvshmem 12 $NVSHMEM_VERSION
160+
116161
CUDA_VERSION=12.8 bash install_nccl.sh
117162

118163
CUDA_VERSION=12.8 bash install_cusparselt.sh

.github/scripts/generate_binary_build_matrix.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | "
5555
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
5656
"nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
57-
"nvidia-nvshmem-cu12==3.2.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
57+
"nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | "
5858
"nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | "
5959
"nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | "
6060
"nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64'"
@@ -71,7 +71,7 @@
7171
"nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | "
7272
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
7373
"nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
74-
"nvidia-nvshmem-cu12==3.2.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
74+
"nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | "
7575
"nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | "
7676
"nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | "
7777
"nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64'"
@@ -88,6 +88,7 @@
8888
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
8989
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
9090
"nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
91+
"nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | "
9192
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
9293
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
9394
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"

0 commit comments

Comments
 (0)