Skip to content

Commit 76f95a6

Browse files
authored
[pyspark] Filter out the unsupported train parameters (dmlc#8355)
1 parent 3901f5d commit 76f95a6

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python-package/xgboost/spark/core.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@
126126
"eval_qid", # Use spark param `qid_col` instead
127127
}
128128

129+
_unsupported_train_params = {
130+
"evals", # Supported by spark param validation_indicator_col
131+
"evals_result", # Won't support yet+
132+
}
133+
129134
_unsupported_predict_params = {
130135
# for classification, we can use rawPrediction as margin
131136
"output_margin",
@@ -515,6 +520,7 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name
515520
k in _unsupported_xgb_params
516521
or k in _unsupported_fit_params
517522
or k in _unsupported_predict_params
523+
or k in _unsupported_train_params
518524
):
519525
raise ValueError(f"Unsupported param '{k}'.")
520526
_extra_params[k] = v
@@ -620,7 +626,9 @@ def _get_distributed_train_params(self, dataset):
620626

621627
@classmethod
622628
def _get_xgb_train_call_args(cls, train_params):
623-
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
629+
xgb_train_default_args = _get_default_params_from_func(
630+
xgboost.train, _unsupported_train_params
631+
)
624632
booster_params, kwargs_params = {}, {}
625633
for key, value in train_params.items():
626634
if key in xgb_train_default_args:

tests/python/test_spark/test_spark_local.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,3 +1126,7 @@ def test_early_stop_param_validation(self):
11261126
classifier = SparkXGBClassifier(early_stopping_rounds=1)
11271127
with pytest.raises(ValueError, match="early_stopping_rounds"):
11281128
classifier.fit(self.cls_df_train)
1129+
1130+
def test_unsupported_params(self):
1131+
with pytest.raises(ValueError, match="evals_result"):
1132+
SparkXGBClassifier(evals_result={})

0 commit comments

Comments
 (0)