Skip to content

Commit d495a18

Browse files
authored
[pyspark] add logs for training (dmlc#9449)
1 parent 7f85484 commit d495a18

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

python-package/xgboost/spark/core.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -924,21 +924,17 @@ def _train_booster(
924924
# Note: Checking `is_cudf_available` in spark worker side because
925925
# spark worker might has different python environment with driver side.
926926
use_qdm = use_qdm and is_cudf_available()
927+
get_logger("XGBoost-PySpark").info(
928+
"Leveraging %s to train with QDM: %s",
929+
booster_params["device"],
930+
"on" if use_qdm else "off",
931+
)
927932

928933
if use_qdm and (booster_params.get("max_bin", None) is not None):
929934
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
930935

931936
_rabit_args = {}
932937
if context.partitionId() == 0:
933-
get_logger("XGBoostPySpark").debug(
934-
"booster params: %s\n"
935-
"train_call_kwargs_params: %s\n"
936-
"dmatrix_kwargs: %s",
937-
booster_params,
938-
train_call_kwargs_params,
939-
dmatrix_kwargs,
940-
)
941-
942938
_rabit_args = _get_rabit_args(context, num_workers)
943939

944940
worker_message = {
@@ -995,7 +991,19 @@ def _run_job() -> Tuple[str, str]:
995991
)
996992
return ret[0], ret[1]
997993

994+
get_logger("XGBoost-PySpark").info(
995+
"Running xgboost-%s on %s workers with"
996+
"\n\tbooster params: %s"
997+
"\n\ttrain_call_kwargs_params: %s"
998+
"\n\tdmatrix_kwargs: %s",
999+
xgboost._py_version(),
1000+
num_workers,
1001+
booster_params,
1002+
train_call_kwargs_params,
1003+
dmatrix_kwargs,
1004+
)
9981005
(config, booster) = _run_job()
1006+
get_logger("XGBoost-PySpark").info("Finished xgboost training!")
9991007

10001008
result_xgb_model = self._convert_to_sklearn_model(
10011009
bytearray(booster, "utf-8"), config

python-package/xgboost/spark/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger:
104104
# If the logger is configured, skip the configure
105105
if not logger.handlers and not logging.getLogger().handlers:
106106
handler = logging.StreamHandler(sys.stderr)
107+
formatter = logging.Formatter(
108+
"%(asctime)s %(levelname)s %(name)s: %(funcName)s %(message)s"
109+
)
110+
handler.setFormatter(formatter)
107111
logger.addHandler(handler)
108112
return logger
109113

0 commit comments

Comments
 (0)