@@ -59,6 +59,8 @@ apt-get update -yq
5959cuda_pkg=" cuda-libraries-${CUDA_VERSION/ ./ -} "
6060nvtx_pkg=" cuda-nvtx-${CUDA_VERSION/ ./ -} "
6161toolkit_pkg=" cuda-toolkit-${CUDA_VERSION/ ./ -} "
62+ major_cudnn_version=$( echo " ${CUDNN_VERSION} " | cut -d ' .' -f 1)
63+ major_cuda_version=$( echo " ${CUDA_VERSION} " | cut -d ' .' -f 1)
6264if ! apt-cache show " $cuda_pkg " ; then
6365 echo " The requested version of CUDA is not available: CUDA $CUDA_VERSION "
6466 exit 1
@@ -69,7 +71,15 @@ apt-get install -yq "$cuda_pkg"
6971
7072if [ " $INSTALL_CUDNN " = " true" ]; then
7173 # Ensure that the requested version of cuDNN is available AND compatible
72- cudnn_pkg_version=" libcudnn8=${CUDNN_VERSION} -1+cuda${CUDA_VERSION} "
74+ # if major cudnn version is 9, then we need to install libcudnn9-cuda-<major_version> package
75+ # else we need to install libcudnn8-cuda-<major_version> package
76+ if [[ $major_cudnn_version -ge " 9" ]]
77+ then
78+ cudnn_pkg_version=" libcudnn9-cuda-${major_cuda_version} =${CUDNN_VERSION} -1"
79+ else
80+ cudnn_pkg_version=" libcudnn8=${CUDNN_VERSION} -1+cuda${CUDA_VERSION} "
81+ fi
82+
7383 if ! apt-cache show " $cudnn_pkg_version " ; then
7484 echo " The requested version of cuDNN is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION "
7585 exit 1
8191
8292if [ " $INSTALL_CUDNNDEV " = " true" ]; then
8393 # Ensure that the requested version of cuDNN development package is available AND compatible
84- cudnn_dev_pkg_version=" libcudnn8-dev=${CUDNN_VERSION} -1+cuda${CUDA_VERSION} "
94+ if [[ $major_cudnn_version -ge " 9" ]]
95+ then
96+ cudnn_dev_pkg_version=" libcudnn9-dev-cuda-${major_cuda_version} =${CUDNN_VERSION} -1"
97+ else
98+ cudnn_dev_pkg_version=" libcudnn8-dev=${CUDNN_VERSION} -1+cuda${CUDA_VERSION} "
99+ fi
100+
85101 if ! apt-cache show " $cudnn_dev_pkg_version " ; then
86102 echo " The requested version of cuDNN development package is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION "
87103 exit 1
0 commit comments