@@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col(
527
527
(DoubleType , FloatType , LongType , IntegerType , ShortType ),
528
528
):
529
529
raise ValueError (
530
- "If feature column is array type, its elements must be number type."
530
+ "If feature column is array type, its elements must be number type, "
531
+ f"got { features_col_datatype .elementType } ."
531
532
)
532
533
features_array_col = features_col .cast (ArrayType (FloatType ())).alias (alias .data )
533
534
elif isinstance (features_col_datatype , VectorUDT ):
@@ -1379,15 +1380,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
1379
1380
# to avoid the `self` object to be pickled to remote.
1380
1381
xgb_sklearn_model = self ._xgb_sklearn_model
1381
1382
1382
- has_base_margin = False
1383
+ base_margin_col = None
1383
1384
if (
1384
1385
self .isDefined (self .base_margin_col )
1385
1386
and self .getOrDefault (self .base_margin_col ) != ""
1386
1387
):
1387
- has_base_margin = True
1388
1388
base_margin_col = col (self .getOrDefault (self .base_margin_col )).alias (
1389
1389
alias .margin
1390
1390
)
1391
+ has_base_margin = base_margin_col is not None
1391
1392
1392
1393
features_col , feature_col_names = self ._get_feature_col (dataset )
1393
1394
enable_sparse_data_optim = self .getOrDefault (self .enable_sparse_data_optim )
@@ -1472,6 +1473,7 @@ def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
1472
1473
yield predict_func (model , X , base_margin )
1473
1474
1474
1475
if has_base_margin :
1476
+ assert base_margin_col is not None
1475
1477
pred_col = predict_udf (struct (* features_col , base_margin_col ))
1476
1478
else :
1477
1479
pred_col = predict_udf (struct (* features_col ))
0 commit comments