71 changes: 52 additions & 19 deletions spark-rapids/spark-rapids.sh
@@ -269,8 +269,7 @@ function execute_with_retries() {
   return 1
 }
 
-function install_spark_rapids() {
-  local -r nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
+function install_gpu_xgboost() {
   local -r dmlc_repo_url='https://repo.maven.apache.org/maven2/ml/dmlc'
 
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
@@ -279,6 +278,29 @@ function install_spark_rapids() {
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
     "${dmlc_repo_url}/xgboost4j-gpu_2.12/${XGBOOST_VERSION}/xgboost4j-gpu_2.12-${XGBOOST_VERSION}.jar" \
     -P /usr/lib/spark/jars/
+}
+
+function check_spark_rapids_jar() {
+  local jars_found
+  jars_found=$(ls /usr/lib/spark/jars/rapids-4-spark_*.jar 2>/dev/null | wc -l)
+  if [[ $jars_found -gt 0 ]]; then
+    echo "RAPIDS Spark plugin JAR found"
+    return 0
+  else
+    echo "RAPIDS Spark plugin JAR not found"
+    return 1
+  fi
Comment on lines +284 to +292
Contributor

medium

Using ls | wc -l to check for file existence is not fully robust. For instance, it can fail in unexpected ways if file names contain newlines (though unlikely here) and it suppresses all errors from ls. A more robust and idiomatic way to check if any files match a glob pattern in bash is to use compgen -G.
Suggested change
-  local jars_found
-  jars_found=$(ls /usr/lib/spark/jars/rapids-4-spark_*.jar 2>/dev/null | wc -l)
-  if [[ $jars_found -gt 0 ]]; then
-    echo "RAPIDS Spark plugin JAR found"
-    return 0
-  else
-    echo "RAPIDS Spark plugin JAR not found"
-    return 1
-  fi
+  if compgen -G "/usr/lib/spark/jars/rapids-4-spark_*.jar" > /dev/null; then
+    echo "RAPIDS Spark plugin JAR found"
+    return 0
+  else
+    echo "RAPIDS Spark plugin JAR not found"
+    return 1
+  fi

+}
+
+function remove_spark_rapids_jar() {
+  rm -f /usr/lib/spark/jars/rapids-4-spark_*.jar
+  echo "Existing RAPIDS Spark plugin JAR removed successfully"
+}
+
+function install_spark_rapids() {
+
+  local -r nvidia_repo_url='https://repo1.maven.org/maven2/com/nvidia'
+
   wget -nv --timeout=30 --tries=5 --retry-connrefused \
     "${nvidia_repo_url}/rapids-4-spark_2.12/${SPARK_RAPIDS_VERSION}/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar" \
     -P /usr/lib/spark/jars/
@@ -807,27 +829,38 @@ function remove_old_backports {
 
 
 function main() {
-  if is_debian && [[ $(echo "${DATAPROC_IMAGE_VERSION} <= 2.1" | bc -l) == 1 ]]; then
-    remove_old_backports
-  fi
-  check_os_and_secure_boot
-  setup_gpu_yarn
-  if [[ "${RUNTIME}" == "SPARK" ]]; then
+  # If a RAPIDS Spark plugin JAR is already installed (common on ML images), replace it with the requested version
+  # ML images by default have Spark RAPIDS and GPU drivers installed
+  if check_spark_rapids_jar; then
+    # This ensures the cluster always uses the desired RAPIDS version, even if a default is present
+    remove_spark_rapids_jar
     install_spark_rapids
-    configure_spark
-    echo "RAPIDS initialized with Spark runtime"
+    echo "RAPIDS Spark plugin JAR replaced successfully"
   else
-    echo "Unsupported RAPIDS Runtime: ${RUNTIME}"
-    exit 1
-  fi
+    # Install GPU drivers and set up the Spark RAPIDS JAR for non-ML images
+    if is_debian && [[ $(echo "${DATAPROC_IMAGE_VERSION} <= 2.1" | bc -l) == 1 ]]; then
+      remove_old_backports
+    fi
+    check_os_and_secure_boot
+    setup_gpu_yarn
+    if [[ "${RUNTIME}" == "SPARK" ]]; then
+      install_spark_rapids
+      install_gpu_xgboost
+      configure_spark
+      echo "RAPIDS initialized with Spark runtime"
+    else
+      echo "Unsupported RAPIDS Runtime: ${RUNTIME}"
+      exit 1
+    fi
 
-  for svc in resourcemanager nodemanager; do
-    if [[ $(systemctl show hadoop-yarn-${svc}.service -p SubState --value) == 'running' ]]; then
-      systemctl restart hadoop-yarn-${svc}.service
+    for svc in resourcemanager nodemanager; do
+      if [[ $(systemctl show hadoop-yarn-${svc}.service -p SubState --value) == 'running' ]]; then
+        systemctl restart hadoop-yarn-${svc}.service
+      fi
+    done
+    if is_debian || is_ubuntu ; then
+      apt-get clean
     fi
-  done
-  if is_debian || is_ubuntu ; then
-    apt-get clean
   fi
 }
Comment on lines 831 to 865
Contributor

high
The YARN services (resourcemanager, nodemanager) need to be restarted for Spark to pick up the newly installed/updated RAPIDS JAR. In the current implementation, this restart only happens in the else block (for non-ML images). When a JAR is updated on an ML image (the if block), the services are not restarted, which means the changes won't take effect until a manual restart. The service restart loop should be moved outside the if/else block to ensure it runs in both scenarios.

function main() {
  # If a RAPIDS Spark plugin JAR is already installed (common on ML images), replace it with the requested version
  # ML images by default have Spark RAPIDS and GPU drivers installed
  if check_spark_rapids_jar; then
    # This ensures the cluster always uses the desired RAPIDS version, even if a default is present
    remove_spark_rapids_jar
    install_spark_rapids
    echo "RAPIDS Spark plugin JAR replaced successfully"
  else
    # Install GPU drivers and setup SPARK RAPIDS JAR for non-ML images
    if is_debian && [[ $(echo "${DATAPROC_IMAGE_VERSION} <= 2.1" | bc -l) == 1 ]]; then
      remove_old_backports
    fi
    check_os_and_secure_boot
    setup_gpu_yarn
    if [[ "${RUNTIME}" == "SPARK" ]]; then
      install_spark_rapids
      install_gpu_xgboost
      configure_spark
      echo "RAPIDS initialized with Spark runtime"
    else
      echo "Unsupported RAPIDS Runtime: ${RUNTIME}"
      exit 1
    fi

    if is_debian || is_ubuntu ; then
      apt-get clean
    fi
  fi

  for svc in resourcemanager nodemanager; do
    if [[ $(systemctl show hadoop-yarn-${svc}.service -p SubState --value) == 'running' ]]; then
      systemctl restart hadoop-yarn-${svc}.service
    fi
  done
}
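The check-remove-install sequence both versions of main() rely on can be exercised without a Dataproc cluster. A sketch under stated assumptions: a temp directory stands in for /usr/lib/spark/jars, touch stands in for the real wget download, and the version strings are invented for the demo:

```shell
#!/usr/bin/env bash
# Dry run of the replace-if-present flow: if any rapids-4-spark_*.jar exists,
# remove it and install the requested one. Temp dir and file names are demo
# stand-ins; no network or Spark installation is touched.
set -euo pipefail

jars_dir=$(mktemp -d)
touch "${jars_dir}/rapids-4-spark_2.12-old.jar"   # pretend a default JAR ships on the image

if compgen -G "${jars_dir}/rapids-4-spark_*.jar" > /dev/null; then
  echo "RAPIDS Spark plugin JAR found"
  rm -f "${jars_dir}"/rapids-4-spark_*.jar        # mirrors remove_spark_rapids_jar
  touch "${jars_dir}/rapids-4-spark_2.12-new.jar" # stand-in for install_spark_rapids (wget)
  echo "RAPIDS Spark plugin JAR replaced successfully"
fi

ls "${jars_dir}"   # only the newly "installed" JAR remains
rm -rf "${jars_dir}"
```

Because the removal glob matches every rapids-4-spark JAR before the new one is written, the sequence is idempotent: re-running it leaves exactly one JAR at the requested version, which is the property the PR wants on ML images that ship a default.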

