File tree Expand file tree Collapse file tree 2 files changed +17
-6
lines changed
python-package/xgboost/spark
tests/test_distributed/test_with_spark Expand file tree Collapse file tree 2 files changed +17
-6
lines changed Original file line number Diff line number Diff line change @@ -449,12 +449,6 @@ def _validate_params(self) -> None:
449
449
"The `exact` tree method is not supported for distributed systems."
450
450
)
451
451
452
- if self .getOrDefault (self .features_cols ):
453
- if not self ._run_on_gpu ():
454
- raise ValueError (
455
- "features_col param with list value requires `device=cuda`."
456
- )
457
-
458
452
if self .getOrDefault ("objective" ) is not None :
459
453
if not isinstance (self .getOrDefault ("objective" ), str ):
460
454
raise ValueError ("Only string type 'objective' param is allowed." )
Original file line number Diff line number Diff line change @@ -1796,6 +1796,23 @@ def check_conf(conf: Config) -> None:
1796
1796
loaded_model = SparkXGBClassifierModel .load (path )
1797
1797
check_conf (loaded_model .getOrDefault (loaded_model .coll_cfg ))
1798
1798
1799
+ def test_classifier_with_multi_cols (self ):
1800
+ df = self .session .createDataFrame (
1801
+ [
1802
+ (1.0 , 2.0 , 0 ),
1803
+ (3.1 , 4.2 , 1 ),
1804
+ ],
1805
+ ["a" , "b" , "label" ],
1806
+ )
1807
+ features = ["a" , "b" ]
1808
+ cls = SparkXGBClassifier (features_col = features , device = "cpu" , n_estimators = 2 )
1809
+ model = cls .fit (df )
1810
+ self .assertEqual (features , model .getOrDefault (model .features_cols ))
1811
+ self .assertTrue (not model .isSet (model .featuresCol ))
1812
+
1813
+ # No exception
1814
+ model .transform (df ).collect ()
1815
+
1799
1816
1800
1817
LTRData = namedtuple (
1801
1818
"LTRData" ,
You can’t perform that action at this time.
0 commit comments