Skip to content

Commit 59c54e3

Browse files
[pyspark] Make QDM optional based on cuDF check (dmlc#8471) (dmlc#8556)
Co-authored-by: WeichenXu <[email protected]>
1 parent 60a8c8e commit 59c54e3

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

python-package/xgboost/compat.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
4343
pandas_concat = None
4444
PANDAS_INSTALLED = False
4545

46+
4647
# sklearn
4748
try:
4849
from sklearn.base import BaseEstimator as XGBModelBase
@@ -72,6 +73,22 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
7273
XGBStratifiedKFold = None
7374

7475

76+
# Module-level logger; is_cudf_available() uses it to report a failed cuDF import.
_logger = logging.getLogger(__name__)
77+
78+
79+
def is_cudf_available() -> bool:
    """Report whether the cuDF package can actually be imported.

    The check is done in two steps: a lookup of the module spec, followed by
    a real import, because a package that is installed may still fail to
    import.

    Returns
    -------
    bool
        ``True`` when ``import cudf`` succeeds.  ``False`` when no ``cudf``
        module can be located, or when importing it raises ``ImportError``
        (the failure is logged together with its traceback).
    """
    try:
        if importlib.util.find_spec("cudf") is None:
            return False
    except ValueError:
        # ``find_spec`` raises ValueError when ``cudf`` is already present in
        # ``sys.modules`` but its ``__spec__`` is unset or None.  A module
        # object exists in that case, so let the import below decide instead
        # of crashing the availability probe.
        pass

    # Keep the ``try`` body down to the one statement that may raise.
    try:
        import cudf  # pylint: disable=import-outside-toplevel,unused-import
    except ImportError:
        _logger.exception("Importing cuDF failed, use DMatrix instead of QDM")
        return False
    return True
90+
91+
7592
class XGBoostLabelEncoder(LabelEncoder):
7693
"""Label encoder with JSON serialization methods."""
7794

python-package/xgboost/spark/core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ShortType,
3333
)
3434
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
35+
from xgboost.compat import is_cudf_available
3536
from xgboost.core import Booster
3637
from xgboost.training import train as worker_train
3738

@@ -759,7 +760,8 @@ def _fit(self, dataset):
759760
k: v for k, v in train_call_kwargs_params.items() if v is not None
760761
}
761762
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
762-
use_qdm = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
763+
764+
use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
763765

764766
def _train_booster(pandas_df_iter):
765767
"""Takes in an RDD partition and outputs a booster for that partition after
@@ -773,6 +775,15 @@ def _train_booster(pandas_df_iter):
773775

774776
gpu_id = None
775777

778+
# If cuDF is not installed, then use DMatrix instead of QDM,
779+
# because without cuDF, DMatrix performs better than QDM.
780+
# Note: `is_cudf_available` is checked on the spark worker side because
781+
# the spark worker might have a different python environment from the driver side.
782+
if use_gpu:
783+
use_qdm = use_hist and is_cudf_available()
784+
else:
785+
use_qdm = use_hist
786+
776787
if use_qdm and (booster_params.get("max_bin", None) is not None):
777788
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
778789

0 commit comments

Comments
 (0)