Skip to content

Commit 3901f5d

Browse files
authored
[pyspark] Cleanup data processing. (dmlc#8344)
* Enable additional combinations of ctor parameters. * Unify procedures for QuantileDMatrix and DMatrix.
1 parent 521086d commit 3901f5d

File tree

5 files changed

+68
-55
lines changed

5 files changed

+68
-55
lines changed

doc/tutorials/spark_estimator.rst

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,11 @@ generate result dataset with 3 new columns:
8383
XGBoost PySpark GPU support
8484
***************************
8585

86-
XGBoost PySpark supports GPU training and prediction. To enable GPU support, you first need
87-
to install the xgboost and cudf packages. Then you can set `use_gpu` parameter to `True`.
86+
XGBoost PySpark supports GPU training and prediction. To enable GPU support, first you
87+
need to install the XGBoost and the `cuDF <https://docs.rapids.ai/api/cudf/stable/>`_
88+
package. Then you can set `use_gpu` parameter to `True`.
8889

89-
Below tutorial will show you how to train a model with XGBoost PySpark GPU on Spark
90+
Below tutorial demonstrates how to train a model with XGBoost PySpark GPU on Spark
9091
standalone cluster.
9192

9293

@@ -138,7 +139,7 @@ in PySpark. Please refer to
138139
conda create -y -n xgboost-env -c conda-forge conda-pack python=3.9
139140
conda activate xgboost-env
140141
pip install xgboost
141-
pip install cudf
142+
conda install cudf -c rapids -c nvidia -c conda-forge
142143
conda pack -f -o xgboost-env.tar.gz
143144
144145
@@ -220,3 +221,6 @@ Below is a simple example submit command for enabling GPU acceleration:
220221
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
221222
--archives xgboost-env.tar.gz#environment \
222223
xgboost_app.py
224+
225+
When rapids plugin is enabled, both of the JVM rapids plugin and the cuDF Python are
226+
required for the acceleration.

python-package/xgboost/spark/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ def _fit(self, dataset):
747747
k: v for k, v in train_call_kwargs_params.items() if v is not None
748748
}
749749
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
750-
use_qdm = booster_params.get("tree_method") in ("hist", "gpu_hist")
750+
use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
751751

752752
def _train_booster(pandas_df_iter):
753753
"""Takes in an RDD partition and outputs a booster for that partition after

python-package/xgboost/spark/data.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,27 @@ def create_dmatrix_from_partitions( # pylint: disable=too-many-arguments
208208

209209
def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
210210
nonlocal n_features
211-
if name in part.columns and part[name].shape[0] > 0:
212-
array = part[name]
213-
if name == alias.data:
211+
if name == alias.data or name in part.columns:
212+
if (
213+
name == alias.data
214+
and feature_cols is not None
215+
and part[feature_cols].shape[0] > 0 # guard against empty partition
216+
):
217+
array: Optional[np.ndarray] = part[feature_cols]
218+
elif part[name].shape[0] > 0:
219+
array = part[name]
214220
array = stack_series(array)
221+
else:
222+
array = None
223+
224+
if name == alias.data and array is not None:
215225
if n_features == 0:
216226
n_features = array.shape[1]
217227
assert n_features == array.shape[1]
218228

229+
if array is None:
230+
return
231+
219232
if is_valid:
220233
valid_data[name].append(array)
221234
else:
@@ -238,26 +251,6 @@ def append_m_sparse(part: pd.DataFrame, name: str, is_valid: bool) -> None:
238251
else:
239252
train_data[name].append(array)
240253

241-
def append_qdm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
242-
"""Preprocessing for QuantileDMatrix."""
243-
nonlocal n_features
244-
if name == alias.data or name in part.columns:
245-
if name == alias.data and feature_cols is not None:
246-
array = part[feature_cols]
247-
else:
248-
array = part[name]
249-
array = stack_series(array)
250-
251-
if name == alias.data:
252-
if n_features == 0:
253-
n_features = array.shape[1]
254-
assert n_features == array.shape[1]
255-
256-
if is_valid:
257-
valid_data[name].append(array)
258-
else:
259-
train_data[name].append(array)
260-
261254
def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
262255
if len(values) == 0:
263256
get_logger("XGBoostPySpark").warning(
@@ -305,13 +298,14 @@ def split_params() -> Tuple[Dict[str, Any], Dict[str, Union[int, float, bool]]]:
305298

306299
meta, params = split_params()
307300

308-
if feature_cols is not None: # rapidsai plugin
309-
assert gpu_id is not None
310-
assert use_qdm is True
311-
cache_partitions(iterator, append_qdm)
301+
if feature_cols is not None and use_qdm:
302+
cache_partitions(iterator, append_fn)
312303
dtrain: DMatrix = make_qdm(train_data, gpu_id, meta, None, params)
313-
elif use_qdm:
314-
cache_partitions(iterator, append_qdm)
304+
elif feature_cols is not None and not use_qdm:
305+
cache_partitions(iterator, append_fn)
306+
dtrain = make(train_data, kwargs)
307+
elif feature_cols is None and use_qdm:
308+
cache_partitions(iterator, append_fn)
315309
dtrain = make_qdm(train_data, gpu_id, meta, None, params)
316310
else:
317311
cache_partitions(iterator, append_fn)

tests/python-gpu/test_gpu_spark/test_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020

2121
@pytest.mark.skipif(**tm.no_cudf())
22-
def test_qdm_ctor() -> None:
23-
run_dmatrix_ctor(is_dqm=True, on_gpu=True)
24-
with pytest.raises(AssertionError):
25-
run_dmatrix_ctor(is_dqm=False, on_gpu=True)
22+
@pytest.mark.parametrize(
23+
"is_feature_cols,is_qdm",
24+
[(True, True), (True, False), (False, True), (False, False)],
25+
)
26+
def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
27+
run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=True)

tests/python/test_spark/test_data.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
stack_series,
1919
)
2020

21+
from xgboost import DMatrix, QuantileDMatrix
22+
2123

2224
def test_stack() -> None:
2325
a = pd.DataFrame({"a": [[1, 2], [3, 4]]})
@@ -37,7 +39,7 @@ def test_stack() -> None:
3739
assert b.shape == (2, 1)
3840

3941

40-
def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
42+
def run_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool, on_gpu: bool) -> None:
4143
rng = np.random.default_rng(0)
4244
dfs: List[pd.DataFrame] = []
4345
n_features = 16
@@ -57,27 +59,35 @@ def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
5759
df = pd.DataFrame(
5860
{alias.label: y, alias.margin: m, alias.weight: w, alias.valid: valid}
5961
)
60-
if on_gpu:
62+
if is_feature_cols:
6163
for j in range(X.shape[1]):
6264
df[f"feat-{j}"] = pd.Series(X[:, j])
6365
else:
6466
df[alias.data] = pd.Series(list(X))
6567
dfs.append(df)
6668

6769
kwargs = {"feature_types": feature_types}
68-
if on_gpu:
69-
cols = [f"feat-{i}" for i in range(n_features)]
70-
train_Xy, valid_Xy = create_dmatrix_from_partitions(
71-
iter(dfs), cols, 0, is_dqm, kwargs, False, True
72-
)
73-
elif is_dqm:
74-
train_Xy, valid_Xy = create_dmatrix_from_partitions(
75-
iter(dfs), None, None, True, kwargs, False, True
76-
)
70+
device_id = 0 if on_gpu else None
71+
cols = [f"feat-{i}" for i in range(n_features)]
72+
feature_cols = cols if is_feature_cols else None
73+
train_Xy, valid_Xy = create_dmatrix_from_partitions(
74+
iter(dfs),
75+
feature_cols,
76+
gpu_id=device_id,
77+
use_qdm=is_qdm,
78+
kwargs=kwargs,
79+
enable_sparse_data_optim=False,
80+
has_validation_col=True,
81+
)
82+
83+
if is_qdm:
84+
assert isinstance(train_Xy, QuantileDMatrix)
85+
assert isinstance(valid_Xy, QuantileDMatrix)
7786
else:
78-
train_Xy, valid_Xy = create_dmatrix_from_partitions(
79-
iter(dfs), None, None, False, kwargs, False, True
80-
)
87+
assert not isinstance(train_Xy, QuantileDMatrix)
88+
assert isinstance(train_Xy, DMatrix)
89+
assert not isinstance(valid_Xy, QuantileDMatrix)
90+
assert isinstance(valid_Xy, DMatrix)
8191

8292
assert valid_Xy is not None
8393
assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches
@@ -109,9 +119,12 @@ def run_dmatrix_ctor(is_dqm: bool, on_gpu: bool) -> None:
109119
np.testing.assert_equal(valid_Xy.feature_types, feature_types)
110120

111121

112-
def test_dmatrix_ctor() -> None:
113-
run_dmatrix_ctor(is_dqm=False, on_gpu=False)
114-
run_dmatrix_ctor(is_dqm=True, on_gpu=False)
122+
@pytest.mark.parametrize(
123+
"is_feature_cols,is_qdm",
124+
[(True, True), (True, False), (False, True), (False, False)],
125+
)
126+
def test_dmatrix_ctor(is_feature_cols: bool, is_qdm: bool) -> None:
127+
run_dmatrix_ctor(is_feature_cols, is_qdm, on_gpu=False)
115128

116129

117130
def test_read_csr_matrix_from_unwrapped_spark_vec() -> None:

0 commit comments

Comments
 (0)