Skip to content

Commit 97ebf9d

Browse files
committed
feat(gpu): Add robust proxy support for driver installation
This PR introduces comprehensive HTTP/S proxy support for the GPU driver installation script, enabling its use in environments with restricted internet egress, such as those using Secure Web Proxy. The `set_proxy` function, controlled by the `http-proxy` and new `http-proxy-pem-uri` metadata attributes, now configures APT, GPG, Java, pip, and Conda to route traffic through the specified proxy. If a PEM certificate URI is provided, the certificate is installed into the OS, Conda, and Java trust stores. The script now correctly handles the proxy scheme (HTTP vs HTTPS) based on the presence of the `http-proxy-pem-uri` metadata. This change was validated in a development environment where all internet access was routed through an explicit proxy. Additional changes: - `README.md` updated to document the new `http-proxy-pem-uri` metadata option and clarify `http-proxy` usage. - GCS caching for the NVIDIA driver is checked earlier to avoid unnecessary HEAD requests to the NVIDIA CDN. - `configure_dkms_certs` is now more idempotent. - Spark RAPIDS versions and repository URL aligned with `spark-rapids/spark-rapids.sh` as part of a move towards a unified GPU/RAPIDS installation script. - Switched to using `/sys/bus/pci/devices/*/uevent` for GPU detection to remove dependency on pciutils - Moved `set_proxy` call earlier in `prepare_to_install`. - Refactored `no_proxy` and `nvcc_gencode` list generation.
1 parent 2eb939b commit 97ebf9d

File tree

2 files changed

+235
-30
lines changed

2 files changed

+235
-30
lines changed

gpu/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,18 @@ sometimes found in the "building from source" sections.
225225
modulus md5sum of the files referenced by both the private and
226226
public secret names.
227227

228+
- `http-proxy: <HOST>:<PORT>` - Optional. The address of an HTTP
229+
proxy to use for internet egress. The script will configure `apt`,
230+
`curl`, `gsutil`, `pip`, `java`, and `gpg` to use this proxy.
231+
232+
- `http-proxy-pem-uri: <GS_PATH>` - Optional. A `gs://` path to the
233+
PEM-encoded certificate file used by the proxy specified in
234+
`http-proxy`. This is needed if the proxy uses TLS and its
235+
certificate is not already trusted by the cluster's default trust
236+
store (e.g., if it's a self-signed certificate or signed by an
237+
internal CA). The script will install this certificate into the
238+
system and Java trust stores.
239+
228240
#### Loading built kernel module
229241

230242
For platforms which do not have pre-built binary kernel drivers, the

gpu/install_gpu_driver.sh

Lines changed: 223 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,47 @@ function set_driver_version() {
266266
export DRIVER_VERSION DRIVER
267267

268268
gpu_driver_url="${nv_xf86_x64_base}/${DRIVER_VERSION}/NVIDIA-Linux-x86_64-${DRIVER_VERSION}.run"
269-
if ! curl ${curl_retry_args} --head "${gpu_driver_url}" | grep -E -q 'HTTP.*200' ; then
270-
echo "No NVIDIA driver exists for DRIVER_VERSION=${DRIVER_VERSION}"
271-
exit 1
269+
270+
# GCS Cache Check Logic
271+
local driver_filename
272+
driver_filename=$(basename "${gpu_driver_url}")
273+
local gcs_cache_path="${pkg_bucket}/nvidia/${driver_filename}"
274+
275+
echo "Checking for cached NVIDIA driver at: ${gcs_cache_path}"
276+
277+
if ! gsutil -q stat "${gcs_cache_path}"; then
278+
echo "Driver not found in GCS cache. Validating URL: ${gpu_driver_url}"
279+
# Use curl to check if the URL is valid (HEAD request)
280+
if curl -sSLfI --connect-timeout 10 --max-time 30 "${gpu_driver_url}" 2>/dev/null | grep -E -q 'HTTP.*200'; then
281+
echo "NVIDIA URL is valid. Downloading to cache..."
282+
local temp_driver_file="${tmpdir}/${driver_filename}"
283+
284+
# Download the file
285+
echo "Downloading from ${gpu_driver_url} to ${temp_driver_file}"
286+
if curl -sSLf -o "${temp_driver_file}" "${gpu_driver_url}"; then
287+
echo "Download complete. Uploading to ${gcs_cache_path}"
288+
# Upload to GCS
289+
if gsutil cp "${temp_driver_file}" "${gcs_cache_path}"; then
290+
echo "Successfully cached to GCS."
291+
rm -f "${temp_driver_file}"
292+
else
293+
echo "ERROR: Failed to upload driver to GCS: ${gcs_cache_path}"
294+
rm -f "${temp_driver_file}"
295+
exit 1
296+
fi
297+
else
298+
echo "ERROR: Failed to download driver from NVIDIA: ${gpu_driver_url}"
299+
rm -f "${temp_driver_file}" # File might not exist if curl failed early
300+
exit 1
301+
fi
302+
else
303+
echo "ERROR: NVIDIA driver URL is not valid or accessible: ${gpu_driver_url}"
304+
exit 1
305+
fi
306+
else
307+
echo "Driver found in GCS cache: ${gcs_cache_path}"
272308
fi
309+
# End of GCS Cache Check Logic
273310
}
274311

275312
function set_cudnn_version() {
@@ -673,14 +710,19 @@ function install_nvidia_nccl() {
673710
# Ada: SM_89, compute_89
674711
# Hopper: SM_90,SM_90a compute_90,compute_90a
675712
# Blackwell: SM_100, compute_100
676-
NVCC_GENCODE="-gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_72,code=sm_72"
677-
NVCC_GENCODE="${NVCC_GENCODE} -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86"
713+
local nvcc_gencode=("-gencode=arch=compute_70,code=sm_70" "-gencode=arch=compute_72,code=sm_72"
714+
"-gencode=arch=compute_80,code=sm_80" "-gencode=arch=compute_86,code=sm_86")
715+
678716
if version_gt "${CUDA_VERSION}" "11.6" ; then
679-
NVCC_GENCODE="${NVCC_GENCODE} -gencode=arch=compute_87,code=sm_87" ; fi
717+
nvcc_gencode+=("-gencode=arch=compute_87,code=sm_87")
718+
fi
680719
if version_ge "${CUDA_VERSION}" "11.8" ; then
681-
NVCC_GENCODE="${NVCC_GENCODE} -gencode=arch=compute_89,code=sm_89" ; fi
720+
nvcc_gencode+=("-gencode=arch=compute_89,code=sm_89")
721+
fi
682722
if version_ge "${CUDA_VERSION}" "12.0" ; then
683-
NVCC_GENCODE="${NVCC_GENCODE} -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_90a,code=compute_90a" ; fi
723+
nvcc_gencode+=("-gencode=arch=compute_90,code=sm_90" "-gencode=arch=compute_90a,code=compute_90a")
724+
fi
725+
NVCC_GENCODE="${nvcc_gencode[*]}"
684726

685727
if is_debuntu ; then
686728
# These packages are required to build .deb packages from source
@@ -866,6 +908,7 @@ function configure_dkms_certs() {
866908
echo "No signing secret provided. skipping";
867909
return 0
868910
fi
911+
if [[ -f "${mok_der}" ]] ; then return 0; fi
869912

870913
mkdir -p "${CA_TMPDIR}"
871914

@@ -1418,6 +1461,11 @@ function install_gpu_agent() {
14181461
"${python_interpreter}" -m venv "${venv}"
14191462
(
14201463
source "${venv}/bin/activate"
1464+
if [[ -v METADATA_HTTP_PROXY_PEM_URI ]]; then
1465+
export REQUESTS_CA_BUNDLE="${trusted_pem_path}"
1466+
pip install pip-system-certs
1467+
unset REQUESTS_CA_BUNDLE
1468+
fi
14211469
python3 -m pip install --upgrade pip
14221470
execute_with_retries python3 -m pip install -r "${install_dir}/requirements.txt"
14231471
)
@@ -1725,7 +1773,7 @@ function mark_incomplete() {
17251773
function install_dependencies() {
17261774
is_complete install-dependencies && return 0
17271775

1728-
pkg_list="pciutils screen"
1776+
pkg_list="screen"
17291777
if is_debuntu ; then execute_with_retries apt-get -y -q install ${pkg_list}
17301778
elif is_rocky ; then execute_with_retries dnf -y -q install ${pkg_list} ; fi
17311779
mark_complete install-dependencies
@@ -1837,7 +1885,7 @@ function main() {
18371885
configure_yarn_resources
18381886

18391887
# Detect NVIDIA GPU
1840-
if (lspci | grep -q NVIDIA); then
1888+
if (grep -h -i PCI_ID=10DE /sys/bus/pci/devices/*/uevent); then
18411889
# if this is called without the MIG script then the drivers are not installed
18421890
migquery_result="$(nvsmi --query-gpu=mig.mode.current --format=csv,noheader)"
18431891
if [[ "${migquery_result}" == "[N/A]" ]] ; then migquery_result="" ; fi
@@ -2154,19 +2202,161 @@ function set_proxy(){
21542202

21552203
if [[ -z "${METADATA_HTTP_PROXY}" ]] ; then return ; fi
21562204

2157-
export http_proxy="${METADATA_HTTP_PROXY}"
2158-
export https_proxy="${METADATA_HTTP_PROXY}"
2159-
export HTTP_PROXY="${METADATA_HTTP_PROXY}"
2160-
export HTTPS_PROXY="${METADATA_HTTP_PROXY}"
2161-
no_proxy="localhost,127.0.0.0/8,::1,metadata.google.internal,169.254.169.254"
2162-
local no_proxy_svc
2163-
for no_proxy_svc in compute secretmanager dns servicedirectory logging \
2164-
bigquery composer pubsub bigquerydatatransfer dataflow \
2165-
storage datafusion ; do
2166-
no_proxy="${no_proxy},${no_proxy_svc}.googleapis.com"
2205+
no_proxy_list=("localhost" "127.0.0.0/8" "::1" "metadata.google.internal" "169.254.169.254")
2206+
2207+
services=( compute secretmanager dns servicedirectory networkmanagement
2208+
bigquery composer pubsub bigquerydatatransfer networkservices
2209+
storage datafusion dataproc certificatemanager networksecurity
2210+
dataflow privateca logging )
2211+
2212+
for svc in "${services[@]}"; do
2213+
no_proxy_list+=("${svc}.googleapis.com")
21672214
done
21682215

2216+
no_proxy="$( IFS=',' ; echo "${no_proxy_list[*]}" )"
2217+
2218+
export http_proxy="http://${METADATA_HTTP_PROXY}"
2219+
export https_proxy="http://${METADATA_HTTP_PROXY}"
2220+
export no_proxy
2221+
export HTTP_PROXY="http://${METADATA_HTTP_PROXY}"
2222+
export HTTPS_PROXY="http://${METADATA_HTTP_PROXY}"
21692223
export NO_PROXY="${no_proxy}"
2224+
2225+
# configure gcloud
2226+
gcloud config set proxy/type http
2227+
gcloud config set proxy/address "${METADATA_HTTP_PROXY%:*}"
2228+
gcloud config set proxy/port "${METADATA_HTTP_PROXY#*:}"
2229+
2230+
# add proxy environment variables to /etc/environment
2231+
grep http_proxy /etc/environment || echo "http_proxy=${http_proxy}" >> /etc/environment
2232+
grep https_proxy /etc/environment || echo "https_proxy=${https_proxy}" >> /etc/environment
2233+
grep no_proxy /etc/environment || echo "no_proxy=${no_proxy}" >> /etc/environment
2234+
grep HTTP_PROXY /etc/environment || echo "HTTP_PROXY=${HTTP_PROXY}" >> /etc/environment
2235+
grep HTTPS_PROXY /etc/environment || echo "HTTPS_PROXY=${HTTPS_PROXY}" >> /etc/environment
2236+
grep NO_PROXY /etc/environment || echo "NO_PROXY=${NO_PROXY}" >> /etc/environment
2237+
2238+
local pkg_proxy_conf_file
2239+
if is_debuntu ; then
2240+
# configure Apt to use the proxy:
2241+
pkg_proxy_conf_file="/etc/apt/apt.conf.d/99proxy"
2242+
cat > "${pkg_proxy_conf_file}" <<EOF
2243+
Acquire::http::Proxy "http://${METADATA_HTTP_PROXY}";
2244+
Acquire::https::Proxy "http://${METADATA_HTTP_PROXY}";
2245+
EOF
2246+
elif is_rocky ; then
2247+
pkg_proxy_conf_file="/etc/dnf/dnf.conf"
2248+
2249+
touch "${pkg_proxy_conf_file}"
2250+
2251+
if grep -q "^proxy=" "${pkg_proxy_conf_file}"; then
2252+
sed -i.bak "s@^proxy=.*@proxy=${HTTP_PROXY}@" "${pkg_proxy_conf_file}"
2253+
elif grep -q "^\[main\]" "${pkg_proxy_conf_file}"; then
2254+
sed -i.bak "/^\[main\]/a proxy=${HTTP_PROXY}" "${pkg_proxy_conf_file}"
2255+
else
2256+
local TMP_FILE=$(mktemp)
2257+
printf "[main]\nproxy=%s\n" "${HTTP_PROXY}" > "${TMP_FILE}"
2258+
2259+
cat "${TMP_FILE}" "${pkg_proxy_conf_file}" > "${pkg_proxy_conf_file}".new
2260+
mv "${pkg_proxy_conf_file}".new "${pkg_proxy_conf_file}"
2261+
2262+
rm "${TMP_FILE}"
2263+
fi
2264+
else
2265+
echo "unknown OS"
2266+
exit 1
2267+
fi
2268+
# configure gpg to use the proxy:
2269+
if ! grep 'keyserver-options http-proxy' /etc/gnupg/dirmngr.conf ; then
2270+
mkdir -p /etc/gnupg
2271+
cat >> /etc/gnupg/dirmngr.conf <<EOF
2272+
keyserver-options http-proxy=http://${METADATA_HTTP_PROXY}
2273+
EOF
2274+
fi
2275+
2276+
# Install the HTTPS proxy's certificate in the system and Java trust databases
2277+
METADATA_HTTP_PROXY_PEM_URI="$(get_metadata_attribute http-proxy-pem-uri '')"
2278+
2279+
if [[ -z "${METADATA_HTTP_PROXY_PEM_URI}" ]] ; then return ; fi
2280+
if [[ ! "${METADATA_HTTP_PROXY_PEM_URI}" =~ ^gs ]] ; then echo "http-proxy-pem-uri value should start with gs://" ; exit 1 ; fi
2281+
2282+
local trusted_pem_dir
2283+
# Add this certificate to the OS trust database
2284+
# When proxy cert is provided, speak to the proxy over https
2285+
if is_debuntu ; then
2286+
trusted_pem_dir="/usr/local/share/ca-certificates"
2287+
mkdir -p "${trusted_pem_dir}"
2288+
proxy_ca_pem="${trusted_pem_dir}/proxy_ca.crt"
2289+
gsutil cp "${METADATA_HTTP_PROXY_PEM_URI}" "${proxy_ca_pem}"
2290+
update-ca-certificates
2291+
trusted_pem_path="/etc/ssl/certs/ca-certificates.crt"
2292+
sed -i -e 's|http://|https://|' "${pkg_proxy_conf_file}"
2293+
elif is_rocky ; then
2294+
trusted_pem_dir="/etc/pki/ca-trust/source/anchors"
2295+
mkdir -p "${trusted_pem_dir}"
2296+
proxy_ca_pem="${trusted_pem_dir}/proxy_ca.crt"
2297+
gsutil cp "${METADATA_HTTP_PROXY_PEM_URI}" "${proxy_ca_pem}"
2298+
update-ca-trust
2299+
trusted_pem_path="/etc/ssl/certs/ca-bundle.crt"
2300+
sed -i -e 's|^proxy=http://|proxy=https://|' "${pkg_proxy_conf_file}"
2301+
else
2302+
echo "unknown OS"
2303+
exit 1
2304+
fi
2305+
2306+
# configure gcloud to respect proxy ca cert
2307+
#gcloud config set core/custom_ca_certs_file "${proxy_ca_pem}"
2308+
2309+
ca_subject="$(openssl crl2pkcs7 -nocrl -certfile "${proxy_ca_pem}" | openssl pkcs7 -print_certs -noout | grep ^subject)"
2310+
# Verify that the proxy certificate is trusted
2311+
local output
2312+
output=$(echo | openssl s_client \
2313+
-connect "${METADATA_HTTP_PROXY}" \
2314+
-proxy "${METADATA_HTTP_PROXY}" \
2315+
-CAfile "${proxy_ca_pem}") || {
2316+
echo "proxy certificate verification failed"
2317+
echo "${output}"
2318+
exit 1
2319+
}
2320+
output=$(echo | openssl s_client \
2321+
-connect "${METADATA_HTTP_PROXY}" \
2322+
-proxy "${METADATA_HTTP_PROXY}" \
2323+
-CAfile "${trusted_pem_path}") || {
2324+
echo "proxy ca certificate not included in system bundle"
2325+
echo "${output}"
2326+
exit 1
2327+
}
2328+
output=$(curl --verbose -fsSL --retry-connrefused --retry 10 --retry-max-time 30 --head "https://google.com" 2>&1)|| {
2329+
echo "curl rejects proxy configuration"
2330+
echo "${curl_output}"
2331+
exit 1
2332+
}
2333+
output=$(curl --verbose -fsSL --retry-connrefused --retry 10 --retry-max-time 30 --head "https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run" 2>&1)|| {
2334+
echo "curl rejects proxy configuration"
2335+
echo "${output}"
2336+
exit 1
2337+
}
2338+
2339+
# Instruct conda to use the system certificate
2340+
echo "Attempting to install pip-system-certs using the proxy certificate..."
2341+
export REQUESTS_CA_BUNDLE="${trusted_pem_path}"
2342+
pip install pip-system-certs
2343+
unset REQUESTS_CA_BUNDLE
2344+
2345+
# For the binaries bundled with conda, append our certificate to the bundle
2346+
openssl crl2pkcs7 -nocrl -certfile /opt/conda/default/ssl/cacert.pem | openssl pkcs7 -print_certs -noout | grep -Fx "${ca_subject}" || {
2347+
cat "${proxy_ca_pem}" >> /opt/conda/default/ssl/cacert.pem
2348+
}
2349+
2350+
sed -i -e 's|http://|https://|' /etc/gnupg/dirmngr.conf
2351+
export http_proxy="https://${METADATA_HTTP_PROXY}"
2352+
export https_proxy="https://${METADATA_HTTP_PROXY}"
2353+
export HTTP_PROXY="https://${METADATA_HTTP_PROXY}"
2354+
export HTTPS_PROXY="https://${METADATA_HTTP_PROXY}"
2355+
sed -i -e 's|proxy=http://|proxy=https://|' -e 's|PROXY=http://|PROXY=https://|' /etc/environment
2356+
2357+
# Instruct the JRE to trust the certificate
2358+
JAVA_HOME="$(awk -F= '/^JAVA_HOME=/ {print $2}' /etc/environment)"
2359+
"${JAVA_HOME}/bin/keytool" -import -cacerts -storepass changeit -noprompt -alias swp_ca -file "${proxy_ca_pem}"
21702360
}
21712361

21722362
function mount_ramdisk(){
@@ -2229,6 +2419,7 @@ function prepare_to_install(){
22292419
# Verify OS compatability and Secure boot state
22302420
check_os
22312421
check_secure_boot
2422+
set_proxy
22322423

22332424
# With the 402.0.0 release of gcloud sdk, `gcloud storage` can be
22342425
# used as a more performant replacement for `gsutil`
@@ -2241,8 +2432,6 @@ function prepare_to_install(){
22412432
fi
22422433
curl_retry_args="-fsSL --retry-connrefused --retry 10 --retry-max-time 30"
22432434

2244-
prepare_gpu_env
2245-
22462435
workdir=/opt/install-dpgce
22472436
tmpdir=/tmp/
22482437
temp_bucket="$(get_metadata_attribute dataproc-temp-bucket)"
@@ -2251,9 +2440,10 @@ function prepare_to_install(){
22512440
readonly bdcfg="/usr/local/bin/bdconfig"
22522441
export DEBIAN_FRONTEND=noninteractive
22532442

2443+
prepare_gpu_env
2444+
22542445
mkdir -p "${workdir}/complete"
22552446
trap exit_handler EXIT
2256-
set_proxy
22572447
mount_ramdisk
22582448

22592449
readonly install_log="${tmpdir}/install.log"
@@ -2391,24 +2581,27 @@ function install_spark_rapids() {
23912581

23922582
# Update SPARK RAPIDS config
23932583
local DEFAULT_SPARK_RAPIDS_VERSION
2584+
local nvidia_repo_url
23942585
DEFAULT_SPARK_RAPIDS_VERSION="24.08.1"
2395-
if version_ge "${DATAPROC_IMAGE_VERSION}" "2.2" ; then
2396-
DEFAULT_SPARK_RAPIDS_VERSION="25.02.1"
2586+
if [[ "${DATAPROC_IMAGE_VERSION}" == "2.0" ]] ; then
2587+
DEFAULT_SPARK_RAPIDS_VERSION="23.08.2" # Final release to support spark 3.1.3
2588+
nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
2589+
elif version_ge "${DATAPROC_IMAGE_VERSION}" "2.2" ; then
2590+
DEFAULT_SPARK_RAPIDS_VERSION="25.08.0"
2591+
nvidia_repo_url='https://edge.urm.nvidia.com/artifactory/sw-spark-maven/com/nvidia'
2592+
elif version_ge "${DATAPROC_IMAGE_VERSION}" "2.1" ; then
2593+
DEFAULT_SPARK_RAPIDS_VERSION="25.08.0"
2594+
nvidia_repo_url='https://edge.urm.nvidia.com/artifactory/sw-spark-maven/com/nvidia'
23972595
fi
23982596
local DEFAULT_XGBOOST_VERSION="1.7.6" # 2.1.3
23992597

24002598
# https://mvnrepository.com/artifact/ml.dmlc/xgboost4j-spark-gpu
24012599
local -r scala_ver="2.12"
24022600

2403-
if [[ "${DATAPROC_IMAGE_VERSION}" == "2.0" ]] ; then
2404-
DEFAULT_SPARK_RAPIDS_VERSION="23.08.2" # Final release to support spark 3.1.3
2405-
fi
2406-
24072601
readonly SPARK_RAPIDS_VERSION=$(get_metadata_attribute 'spark-rapids-version' ${DEFAULT_SPARK_RAPIDS_VERSION})
24082602
readonly XGBOOST_VERSION=$(get_metadata_attribute 'xgboost-version' ${DEFAULT_XGBOOST_VERSION})
24092603

24102604
local -r rapids_repo_url='https://repo1.maven.org/maven2/ai/rapids'
2411-
local -r nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
24122605
local -r dmlc_repo_url='https://repo.maven.apache.org/maven2/ml/dmlc'
24132606

24142607
local jar_basename

0 commit comments

Comments
 (0)