File tree Expand file tree Collapse file tree 2 files changed +17
-6
lines changed
python-package/xgboost/spark
tests/test_distributed/test_with_spark Expand file tree Collapse file tree 2 files changed +17
-6
lines changed Original file line number Diff line number Diff line change @@ -449,12 +449,6 @@ def _validate_params(self) -> None:
449449 "The `exact` tree method is not supported for distributed systems."
450450 )
451451
452- if self .getOrDefault (self .features_cols ):
453- if not self ._run_on_gpu ():
454- raise ValueError (
455- "features_col param with list value requires `device=cuda`."
456- )
457-
458452 if self .getOrDefault ("objective" ) is not None :
459453 if not isinstance (self .getOrDefault ("objective" ), str ):
460454 raise ValueError ("Only string type 'objective' param is allowed." )
Original file line number Diff line number Diff line change @@ -1796,6 +1796,23 @@ def check_conf(conf: Config) -> None:
17961796 loaded_model = SparkXGBClassifierModel .load (path )
17971797 check_conf (loaded_model .getOrDefault (loaded_model .coll_cfg ))
17981798
1799+ def test_classifier_with_multi_cols (self ):
1800+ df = self .session .createDataFrame (
1801+ [
1802+ (1.0 , 2.0 , 0 ),
1803+ (3.1 , 4.2 , 1 ),
1804+ ],
1805+ ["a" , "b" , "label" ],
1806+ )
1807+ features = ["a" , "b" ]
1808+ cls = SparkXGBClassifier (features_col = features , device = "cpu" , n_estimators = 2 )
1809+ model = cls .fit (df )
1810+ self .assertEqual (features , model .getOrDefault (model .features_cols ))
1811+ self .assertTrue (not model .isSet (model .featuresCol ))
1812+
1813+ # No exception
1814+ model .transform (df ).collect ()
1815+
17991816
18001817LTRData = namedtuple (
18011818 "LTRData" ,
You can’t perform that action at this time.
0 commit comments