Skip to content

Commit 60a8c8e

Browse files
trivialfiswbo4958
andauthored
[pyspark] sort qid for SparkRanker (dmlc#8497) (dmlc#8555)
* [pyspark] sort qid for SparkRandker * resolve comments Co-authored-by: Bobby Wang <[email protected]>
1 parent 58bc225 commit 60a8c8e

File tree

2 files changed

+65
-32
lines changed

2 files changed

+65
-32
lines changed

python-package/xgboost/spark/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# type: ignore
22
"""Xgboost pyspark integration submodule for core code."""
33
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
4-
# pylint: disable=too-few-public-methods, too-many-lines
4+
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
55
import json
66
from typing import Iterator, Optional, Tuple
77

@@ -728,6 +728,10 @@ def _fit(self, dataset):
728728
else:
729729
dataset = dataset.repartition(num_workers)
730730

731+
if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
732+
# XGBoost requires qid to be sorted for each partition
733+
dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
734+
731735
train_params = self._get_distributed_train_params(dataset)
732736
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
733737
train_params

tests/python/test_spark/test_spark_local.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -390,28 +390,6 @@ def setUp(self):
390390
"expected_prediction_with_base_margin",
391391
],
392392
)
393-
self.ranker_df_train = self.session.createDataFrame(
394-
[
395-
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
396-
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
397-
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
398-
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
399-
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
400-
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
401-
],
402-
["features", "label", "qid"],
403-
)
404-
self.ranker_df_test = self.session.createDataFrame(
405-
[
406-
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
407-
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
408-
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
409-
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
410-
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
411-
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
412-
],
413-
["features", "qid", "expected_prediction"],
414-
)
415393

416394
self.reg_df_sparse_train = self.session.createDataFrame(
417395
[
@@ -1039,15 +1017,6 @@ def test_classifier_with_sparse_optim(self):
10391017
for row1, row2 in zip(pred_result, pred_result2):
10401018
self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3))
10411019

1042-
def test_ranker(self):
1043-
ranker = SparkXGBRanker(qid_col="qid")
1044-
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
1045-
model = ranker.fit(self.ranker_df_train)
1046-
pred_result = model.transform(self.ranker_df_test).collect()
1047-
1048-
for row in pred_result:
1049-
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
1050-
10511020
def test_empty_validation_data(self) -> None:
10521021
for tree_method in [
10531022
"hist",
@@ -1130,3 +1099,63 @@ def test_early_stop_param_validation(self):
11301099
def test_unsupported_params(self):
11311100
with pytest.raises(ValueError, match="evals_result"):
11321101
SparkXGBClassifier(evals_result={})
1102+
1103+
1104+
class XgboostRankerLocalTest(SparkTestCase):
1105+
def setUp(self):
1106+
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
1107+
self.ranker_df_train = self.session.createDataFrame(
1108+
[
1109+
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
1110+
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
1111+
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
1112+
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
1113+
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
1114+
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
1115+
],
1116+
["features", "label", "qid"],
1117+
)
1118+
self.ranker_df_test = self.session.createDataFrame(
1119+
[
1120+
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
1121+
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
1122+
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
1123+
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
1124+
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
1125+
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
1126+
],
1127+
["features", "qid", "expected_prediction"],
1128+
)
1129+
self.ranker_df_train_1 = self.session.createDataFrame(
1130+
[
1131+
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
1132+
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
1133+
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9),
1134+
(Vectors.dense(1.0, 2.0, 3.0), 0, 8),
1135+
(Vectors.dense(4.0, 5.0, 6.0), 1, 8),
1136+
(Vectors.dense(9.0, 4.0, 8.0), 2, 8),
1137+
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7),
1138+
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7),
1139+
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7),
1140+
(Vectors.dense(1.0, 2.0, 3.0), 0, 6),
1141+
(Vectors.dense(4.0, 5.0, 6.0), 1, 6),
1142+
(Vectors.dense(9.0, 4.0, 8.0), 2, 6),
1143+
]
1144+
* 4,
1145+
["features", "label", "qid"],
1146+
)
1147+
1148+
def test_ranker(self):
1149+
ranker = SparkXGBRanker(qid_col="qid")
1150+
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
1151+
model = ranker.fit(self.ranker_df_train)
1152+
pred_result = model.transform(self.ranker_df_test).collect()
1153+
1154+
for row in pred_result:
1155+
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
1156+
1157+
def test_ranker_qid_sorted(self):
1158+
ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
1159+
assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
1160+
model = ranker.fit(self.ranker_df_train_1)
1161+
model.transform(self.ranker_df_test).collect()

0 commit comments

Comments
 (0)