Skip to content

Commit 532318d

Browse files
trivialfiswbo4958
andauthored
[backport][pyspark] Support columnar input for cpu pipeline (dmlc#11299) (dmlc#11301)
Co-authored-by: Bobby Wang <[email protected]>
1 parent 3adb12d commit 532318d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

python-package/xgboost/spark/core.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -449,12 +449,6 @@ def _validate_params(self) -> None:
449449
"The `exact` tree method is not supported for distributed systems."
450450
)
451451

452-
if self.getOrDefault(self.features_cols):
453-
if not self._run_on_gpu():
454-
raise ValueError(
455-
"features_col param with list value requires `device=cuda`."
456-
)
457-
458452
if self.getOrDefault("objective") is not None:
459453
if not isinstance(self.getOrDefault("objective"), str):
460454
raise ValueError("Only string type 'objective' param is allowed.")

tests/test_distributed/test_with_spark/test_spark_local.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,6 +1796,23 @@ def check_conf(conf: Config) -> None:
17961796
loaded_model = SparkXGBClassifierModel.load(path)
17971797
check_conf(loaded_model.getOrDefault(loaded_model.coll_cfg))
17981798

1799+
def test_classifier_with_multi_cols(self):
1800+
df = self.session.createDataFrame(
1801+
[
1802+
(1.0, 2.0, 0),
1803+
(3.1, 4.2, 1),
1804+
],
1805+
["a", "b", "label"],
1806+
)
1807+
features = ["a", "b"]
1808+
cls = SparkXGBClassifier(features_col=features, device="cpu", n_estimators=2)
1809+
model = cls.fit(df)
1810+
self.assertEqual(features, model.getOrDefault(model.features_cols))
1811+
self.assertTrue(not model.isSet(model.featuresCol))
1812+
1813+
# No exception
1814+
model.transform(df).collect()
1815+
17991816

18001817
LTRData = namedtuple(
18011818
"LTRData",

0 commit comments

Comments
 (0)