Skip to content

Commit 748d516

Browse files
authored
[pyspark] Enable running GPU tests on variable number of GPUs. (dmlc#8335)
1 parent 4633b47 commit 748d516

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed
Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
11
#!/bin/bash
22

3-
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
3+
# This script is only made for running XGBoost tests on official CI where we have access
4+
# to a 4-GPU cluster. The discovery command is for running tests on a local machine where
5+
# the driver and the GPU worker might be the same machine for the ease of development.
6+
7+
# Print a Spark GPU resource description as a single JSON object on stdout:
#   {"name": "gpu", "addresses": ["0", "1", ...]}
if ! command -v nvidia-smi &> /dev/null
then
    # nvidia-smi is not on PATH, so the GPUs cannot be queried: default to 4 GPUs
    echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
    exit
else
    # https://github.com/apache/spark/blob/master/examples/src/main/scripts/getGpusResources.sh
    # nvidia-smi prints one GPU index per line; sed joins the lines with `","` so that
    # wrapping the result in `["` ... `"]` below yields a JSON array of strings.
    ADDRS=$(nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g')
    # Quote the whole payload: unquoted, it is subject to word splitting and pathname
    # expansion (the `[...]` part is a glob bracket expression).
    echo "{\"name\": \"gpu\", \"addresses\":[\"$ADDRS\"]}"
fi

tests/python-gpu/test_gpu_spark/test_gpu_spark.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import logging
3+
import subprocess
24
import sys
35

46
import pytest
@@ -18,8 +20,20 @@
1820
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
1921

2022
# Path of the GPU discovery shell script that `get_devices` executes.
# NOTE(review): this is a relative path, so it presumably requires the tests to be
# launched from the repository root -- confirm against the CI invocation.
gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
21-
executor_gpu_amount = 4
22-
executor_cores = 4
23+
24+
25+
def get_devices():
    """Return the GPU addresses reported by the discovery script.

    The script at ``gpu_discovery_script_path`` is run as a local subprocess and
    its standard output is parsed as JSON. Because the script runs on the machine
    that imports this module, the result is only meaningful when the driver and
    the worker are the same machine.
    """
    proc = subprocess.run(gpu_discovery_script_path, stdout=subprocess.PIPE)
    assert proc.returncode == 0, "Failed to execute discovery script."
    # The script prints one JSON object; only its "addresses" entry is needed.
    payload = json.loads(proc.stdout.decode("utf-8"))
    return payload["addresses"]
33+
34+
35+
# All three settings equal the number of GPU addresses returned by get_devices().
# The discovery script is therefore executed once, when this module is imported.
executor_gpu_amount = executor_cores = num_workers = len(get_devices())
2438

2539

0 commit comments

Comments
 (0)