#!/bin/bash
22
#######################################
# Strict version comparison helper.
# Despite the "lte" in its name, this succeeds only when the SECOND version
# sorts strictly before the FIRST one in `sort -V` order; equal versions fail.
# Arguments:
#   $1 - reference version string
#   $2 - version string tested against $1
# Returns:
#   0 if $2 is strictly lower than $1, non-zero otherwise (including $1 == $2)
#######################################
verlte() {
  # Equal versions never count as "lower".
  if [ "$1" = "$2" ]; then
    return 1
  fi

  local lowest
  # printf rather than `echo -e`: arguments are never treated as options or
  # escape sequences.
  lowest=$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n1)

  # The function's status is the status of this final comparison.
  [ "$2" = "$lowest" ]
}
6+
# Put the CUDA compat libraries on LD_LIBRARY_PATH when the image ships them
# and the installed Nvidia kernel driver is older than the version the compat
# package was built for.
if [ -f /usr/local/cuda/compat/libcuda.so.1 ]; then
  # Fields 3+ of the dot-separated symlink target are taken as the driver
  # version (presumably the target is named libcuda.so.<version> -- confirm).
  CUDA_COMPAT_MAX_DRIVER_VERSION=$(readlink /usr/local/cuda/compat/libcuda.so.1 | cut -d"." -f 3-)
  echo "CUDA compat package requires Nvidia driver ≤${CUDA_COMPAT_MAX_DRIVER_VERSION}"
  cat /proc/driver/nvidia/version
  # Best effort: a missing /proc entry leaves the version empty instead of
  # aborting the script.
  NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module *\([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true)
  echo "Current installed Nvidia driver version is ${NVIDIA_DRIVER_VERSION}"
  if [ -z "${NVIDIA_DRIVER_VERSION}" ]; then
    # An empty version would sort below every real version and be mistaken
    # for an old driver, so handle it explicitly.
    echo "Skip CUDA compat libs setup as Nvidia driver version could not be determined"
  # Test verlte's exit status directly; wrapping it in [ $(...) ] tests its
  # (always empty) stdout and therefore never takes this branch.
  elif verlte "${CUDA_COMPAT_MAX_DRIVER_VERSION}" "${NVIDIA_DRIVER_VERSION}"; then
    echo "Setup CUDA compatibility libs path to LD_LIBRARY_PATH"
    # Only append the previous value when it is non-empty: a trailing ':'
    # would add the current directory to the library search path.
    export LD_LIBRARY_PATH="/usr/local/cuda/compat${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
    echo "${LD_LIBRARY_PATH}"
  else
    echo "Skip CUDA compat libs setup as newer Nvidia driver is installed"
  fi
else
  echo "Skip CUDA compat libs setup as package not found"
fi
23+
324if [[ -z " ${HF_MODEL_ID} " ]]; then
425 echo " HF_MODEL_ID must be set"
526 exit 1
@@ -15,9 +36,37 @@ if ! command -v nvidia-smi &> /dev/null; then
1536 exit 1
1637fi
1738
# Query GPU name using nvidia-smi
# nvidia-smi runs on its own so its exit status can be captured: after a
# pipeline, $? reports the last stage (awk) and would hide a failed query.
gpu_query_output=$(nvidia-smi --query-gpu=gpu_name --format=csv)
gpu_query_status=$?
# Keep only the second output line (the first line is assumed to be the CSV
# header, so line 2 is the first GPU's name).
gpu_name=$(printf '%s\n' "${gpu_query_output}" | awk 'NR==2')
if [ "${gpu_query_status}" -ne 0 ]; then
  # Show whatever nvidia-smi printed; the script deliberately continues so
  # later steps can still fall back to defaults.
  echo "Error: ${gpu_query_output}"
  echo "Query gpu_name failed"
else
  echo "Query gpu_name succeeded. Printing output: ${gpu_name}"
fi
47+
#######################################
# Function to get compute capability based on GPU name.
# Arguments:
#   $1 - GPU name string (e.g. as reported by nvidia-smi)
# Outputs:
#   Writes a two-digit compute capability code to stdout:
#   86 for names containing "A10G", 80 for "A100", 90 for "H100",
#   and 80 for anything else.
#######################################
get_compute_cap() {
  # local: do not overwrite the caller's global gpu_name.
  local gpu_name="$1"

  case "${gpu_name}" in
    *A10G*) echo "86" ;;
    *A100*) echo "80" ;;
    *H100*) echo "90" ;;
    *) echo "80" ;; # Default compute capability
  esac
}
65+
# Choose the compute capability: a non-empty CUDA_COMPUTE_CAP from the
# environment takes precedence; otherwise it is derived from the GPU name.
# The ":-" defaults keep the test safe when the variables are unset.
if [[ -z "${CUDA_COMPUTE_CAP:-}" ]]; then
  compute_cap=$(get_compute_cap "${gpu_name:-}")
  echo "the compute_cap is ${compute_cap}"
else
  compute_cap=${CUDA_COMPUTE_CAP}
fi
0 commit comments