64
64
from xgboost .sklearn import DEFAULT_N_ESTIMATORS , XGBModel , _can_use_qdm
65
65
from xgboost .training import train as worker_train
66
66
67
+ from .._typing import ArrayLike
67
68
from .data import (
68
69
_read_csr_matrix_from_unwrapped_spark_vec ,
69
70
alias ,
@@ -1117,12 +1118,86 @@ def _get_feature_col(
1117
1118
)
1118
1119
return features_col , feature_col_names
1119
1120
1121
+ def _get_pred_contrib_col_name (self ) -> Optional [str ]:
1122
+ """Return the pred_contrib_col col name"""
1123
+ pred_contrib_col_name = None
1124
+ if (
1125
+ self .isDefined (self .pred_contrib_col )
1126
+ and self .getOrDefault (self .pred_contrib_col ) != ""
1127
+ ):
1128
+ pred_contrib_col_name = self .getOrDefault (self .pred_contrib_col )
1129
+
1130
+ return pred_contrib_col_name
1131
+
1132
+ def _out_schema (self ) -> Tuple [bool , str ]:
1133
+ """Return the bool to indicate if it's a single prediction, true is single prediction,
1134
+ and the returned type of the user-defined function. The value must
1135
+ be a DDL-formatted type string."""
1136
+
1137
+ if self ._get_pred_contrib_col_name () is not None :
1138
+ return False , f"{ pred .prediction } double, { pred .pred_contrib } array<double>"
1139
+
1140
+ return True , "double"
1141
+
1142
def _get_predict_func(self) -> Callable:
    """Build the prediction function that runs on the executor side.

    Everything needed from ``self`` is resolved here into local variables,
    so the returned closure captures only those locals and not ``self``.

    Returns
    -------
    A callable ``_predict(model, X, base_margin)``. It returns a
    ``pd.Series`` of predictions, or a ``pd.DataFrame`` with prediction and
    contribution columns when a contribution column name is configured.
    """
    predict_kwargs = self._gen_predict_params_dict()
    contrib_col = self._get_pred_contrib_col_name()

    def _predict(
        model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
    ) -> Union[pd.DataFrame, pd.Series]:
        predictions = pd.Series(
            model.predict(
                X,
                base_margin=base_margin,
                validate_features=False,
                **predict_kwargs,
            )
        )

        # No contribution column requested: the bare Series is the result.
        if contrib_col is None:
            return predictions

        contribs = pred_contribs(model, X, base_margin)
        return pd.DataFrame(
            data={
                pred.prediction: predictions,
                # One entry per row, each holding that row's contributions.
                pred.pred_contrib: pd.Series(list(contribs)),
            }
        )

    return _predict
1168
+
1169
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
    """Attach the prediction UDF output to ``dataset`` as output columns.

    Parameters
    ----------
    dataset :
        The DataFrame being transformed.
    pred_col :
        The column produced by the prediction UDF. It is either a single
        prediction column or a struct, as reported by ``_out_schema``.

    Returns
    -------
    ``dataset`` with the configured prediction column added and, when a
    contribution column name is configured, the contribution column too.
    """
    output_name = self.getOrDefault(self.predictionCol)
    is_single, _ = self._out_schema()

    if is_single:
        # The UDF result is the prediction itself; add it only when a
        # prediction column name is set.
        if output_name:
            return dataset.withColumn(output_name, pred_col)
        return dataset

    # Struct output: park it in a temporary column, pull the individual
    # fields out of it, then drop the temporary column.
    struct_name = "_prediction_struct"
    dataset = dataset.withColumn(struct_name, pred_col)

    if output_name:
        dataset = dataset.withColumn(
            output_name, getattr(col(struct_name), pred.prediction)
        )

    contrib_name = self._get_pred_contrib_col_name()
    if contrib_name is not None:
        dataset = dataset.withColumn(
            contrib_name,
            array_to_vector(getattr(col(struct_name), pred.pred_contrib)),
        )

    return dataset.drop(struct_name)
1195
+
1120
1196
def _transform (self , dataset : DataFrame ) -> DataFrame :
1121
1197
# pylint: disable=too-many-statements, too-many-locals
1122
1198
# Save xgb_sklearn_model and predict_params to be local variable
1123
1199
# to avoid the `self` object to be pickled to remote.
1124
1200
xgb_sklearn_model = self ._xgb_sklearn_model
1125
- predict_params = self ._gen_predict_params_dict ()
1126
1201
1127
1202
has_base_margin = False
1128
1203
if (
@@ -1137,18 +1212,9 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
1137
1212
features_col , feature_col_names = self ._get_feature_col (dataset )
1138
1213
enable_sparse_data_optim = self .getOrDefault (self .enable_sparse_data_optim )
1139
1214
1140
- pred_contrib_col_name = None
1141
- if (
1142
- self .isDefined (self .pred_contrib_col )
1143
- and self .getOrDefault (self .pred_contrib_col ) != ""
1144
- ):
1145
- pred_contrib_col_name = self .getOrDefault (self .pred_contrib_col )
1215
+ predict_func = self ._get_predict_func ()
1146
1216
1147
- single_pred = True
1148
- schema = "double"
1149
- if pred_contrib_col_name :
1150
- single_pred = False
1151
- schema = f"{ pred .prediction } double, { pred .pred_contrib } array<double>"
1217
+ _ , schema = self ._out_schema ()
1152
1218
1153
1219
@pandas_udf (schema ) # type: ignore
1154
1220
def predict_udf (iterator : Iterator [pd .DataFrame ]) -> Iterator [pd .Series ]:
@@ -1168,48 +1234,14 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
1168
1234
else :
1169
1235
base_margin = None
1170
1236
1171
- data = {}
1172
- preds = model .predict (
1173
- X ,
1174
- base_margin = base_margin ,
1175
- validate_features = False ,
1176
- ** predict_params ,
1177
- )
1178
- data [pred .prediction ] = pd .Series (preds )
1179
-
1180
- if pred_contrib_col_name :
1181
- contribs = pred_contribs (model , X , base_margin )
1182
- data [pred .pred_contrib ] = pd .Series (list (contribs ))
1183
- yield pd .DataFrame (data = data )
1184
- else :
1185
- yield data [pred .prediction ]
1237
+ yield predict_func (model , X , base_margin )
1186
1238
1187
1239
if has_base_margin :
1188
1240
pred_col = predict_udf (struct (* features_col , base_margin_col ))
1189
1241
else :
1190
1242
pred_col = predict_udf (struct (* features_col ))
1191
1243
1192
- prediction_col_name = self .getOrDefault (self .predictionCol )
1193
-
1194
- if single_pred :
1195
- dataset = dataset .withColumn (prediction_col_name , pred_col )
1196
- else :
1197
- pred_struct_col = "_prediction_struct"
1198
- dataset = dataset .withColumn (pred_struct_col , pred_col )
1199
-
1200
- dataset = dataset .withColumn (
1201
- prediction_col_name , getattr (col (pred_struct_col ), pred .prediction )
1202
- )
1203
-
1204
- if pred_contrib_col_name :
1205
- dataset = dataset .withColumn (
1206
- pred_contrib_col_name ,
1207
- array_to_vector (getattr (col (pred_struct_col ), pred .pred_contrib )),
1208
- )
1209
-
1210
- dataset = dataset .drop (pred_struct_col )
1211
-
1212
- return dataset
1244
+ return self ._post_transform (dataset , pred_col )
1213
1245
1214
1246
1215
1247
class _ClassificationModel ( # pylint: disable=abstract-method
@@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
1221
1253
.. Note:: This API is experimental.
1222
1254
"""
1223
1255
1224
- def _transform (self , dataset : DataFrame ) -> DataFrame :
1225
- # pylint: disable=too-many-statements, too-many-locals
1226
- # Save xgb_sklearn_model and predict_params to be local variable
1227
- # to avoid the `self` object to be pickled to remote.
1228
- xgb_sklearn_model = self ._xgb_sklearn_model
1229
- predict_params = self ._gen_predict_params_dict ()
1256
+ def _out_schema (self ) -> Tuple [bool , str ]:
1257
+ schema = (
1258
+ f"{ pred .raw_prediction } array<double>, { pred .prediction } double,"
1259
+ f" { pred .probability } array<double>"
1260
+ )
1261
+ if self ._get_pred_contrib_col_name () is not None :
1262
+ # We will force setting strict_shape to True when predicting contribs,
1263
+ # So, it will also output 3-D shape result.
1264
+ schema = f"{ schema } , { pred .pred_contrib } array<array<double>>"
1230
1265
1231
- has_base_margin = False
1232
- if (
1233
- self .isDefined (self .base_margin_col )
1234
- and self .getOrDefault (self .base_margin_col ) != ""
1235
- ):
1236
- has_base_margin = True
1237
- base_margin_col = col (self .getOrDefault (self .base_margin_col )).alias (
1238
- alias .margin
1239
- )
1266
+ return False , schema
1267
+
1268
+ def _get_predict_func (self ) -> Callable :
1269
+ predict_params = self ._gen_predict_params_dict ()
1270
+ pred_contrib_col_name = self ._get_pred_contrib_col_name ()
1240
1271
1241
1272
def transform_margin (margins : np .ndarray ) -> Tuple [np .ndarray , np .ndarray ]:
1242
1273
if margins .ndim == 1 :
@@ -1251,76 +1282,38 @@ def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
1251
1282
class_probs = softmax (raw_preds , axis = 1 )
1252
1283
return raw_preds , class_probs
1253
1284
1254
- features_col , feature_col_names = self ._get_feature_col (dataset )
1255
- enable_sparse_data_optim = self .getOrDefault (self .enable_sparse_data_optim )
1256
-
1257
- pred_contrib_col_name = None
1258
- if (
1259
- self .isDefined (self .pred_contrib_col )
1260
- and self .getOrDefault (self .pred_contrib_col ) != ""
1261
- ):
1262
- pred_contrib_col_name = self .getOrDefault (self .pred_contrib_col )
1263
-
1264
- schema = (
1265
- f"{ pred .raw_prediction } array<double>, { pred .prediction } double,"
1266
- f" { pred .probability } array<double>"
1267
- )
1268
- if pred_contrib_col_name :
1269
- # We will force setting strict_shape to True when predicting contribs,
1270
- # So, it will also output 3-D shape result.
1271
- schema = f"{ schema } , { pred .pred_contrib } array<array<double>>"
1272
-
1273
- @pandas_udf (schema ) # type: ignore
1274
- def predict_udf (
1275
- iterator : Iterator [Tuple [pd .Series , ...]]
1276
- ) -> Iterator [pd .DataFrame ]:
1277
- assert xgb_sklearn_model is not None
1278
- model = xgb_sklearn_model
1279
- for data in iterator :
1280
- if enable_sparse_data_optim :
1281
- X = _read_csr_matrix_from_unwrapped_spark_vec (data )
1282
- else :
1283
- if feature_col_names is not None :
1284
- X = data [feature_col_names ] # type: ignore
1285
- else :
1286
- X = stack_series (data [alias .data ])
1287
-
1288
- if has_base_margin :
1289
- base_margin = stack_series (data [alias .margin ])
1290
- else :
1291
- base_margin = None
1292
-
1293
- margins = model .predict (
1294
- X ,
1295
- base_margin = base_margin ,
1296
- output_margin = True ,
1297
- validate_features = False ,
1298
- ** predict_params ,
1299
- )
1300
- raw_preds , class_probs = transform_margin (margins )
1301
-
1302
- # It seems that they use argmax of class probs,
1303
- # not of margin to get the prediction (Note: scala implementation)
1304
- preds = np .argmax (class_probs , axis = 1 )
1305
- result : Dict [str , pd .Series ] = {
1306
- pred .raw_prediction : pd .Series (list (raw_preds )),
1307
- pred .prediction : pd .Series (preds ),
1308
- pred .probability : pd .Series (list (class_probs )),
1309
- }
1285
+ def _predict (
1286
+ model : XGBModel , X : ArrayLike , base_margin : Optional [np .ndarray ]
1287
+ ) -> Union [pd .DataFrame , pd .Series ]:
1288
+ margins = model .predict (
1289
+ X ,
1290
+ base_margin = base_margin ,
1291
+ output_margin = True ,
1292
+ validate_features = False ,
1293
+ ** predict_params ,
1294
+ )
1295
+ raw_preds , class_probs = transform_margin (margins )
1296
+
1297
+ # It seems that they use argmax of class probs,
1298
+ # not of margin to get the prediction (Note: scala implementation)
1299
+ preds = np .argmax (class_probs , axis = 1 )
1300
+ result : Dict [str , pd .Series ] = {
1301
+ pred .raw_prediction : pd .Series (list (raw_preds )),
1302
+ pred .prediction : pd .Series (preds ),
1303
+ pred .probability : pd .Series (list (class_probs )),
1304
+ }
1310
1305
1311
- if pred_contrib_col_name :
1312
- contribs = pred_contribs (model , X , base_margin , strict_shape = True )
1313
- result [pred .pred_contrib ] = pd .Series (list (contribs .tolist ()))
1306
+ if pred_contrib_col_name is not None :
1307
+ contribs = pred_contribs (model , X , base_margin , strict_shape = True )
1308
+ result [pred .pred_contrib ] = pd .Series (list (contribs .tolist ()))
1314
1309
1315
- yield pd .DataFrame (data = result )
1310
+ return pd .DataFrame (data = result )
1316
1311
1317
- if has_base_margin :
1318
- pred_struct = predict_udf (struct (* features_col , base_margin_col ))
1319
- else :
1320
- pred_struct = predict_udf (struct (* features_col ))
1312
+ return _predict
1321
1313
1314
+ def _post_transform (self , dataset : DataFrame , pred_col : Column ) -> DataFrame :
1322
1315
pred_struct_col = "_prediction_struct"
1323
- dataset = dataset .withColumn (pred_struct_col , pred_struct )
1316
+ dataset = dataset .withColumn (pred_struct_col , pred_col )
1324
1317
1325
1318
raw_prediction_col_name = self .getOrDefault (self .rawPredictionCol )
1326
1319
if raw_prediction_col_name :
@@ -1342,7 +1335,8 @@ def predict_udf(
1342
1335
array_to_vector (getattr (col (pred_struct_col ), pred .probability )),
1343
1336
)
1344
1337
1345
- if pred_contrib_col_name :
1338
+ pred_contrib_col_name = self ._get_pred_contrib_col_name ()
1339
+ if pred_contrib_col_name is not None :
1346
1340
dataset = dataset .withColumn (
1347
1341
pred_contrib_col_name ,
1348
1342
getattr (col (pred_struct_col ), pred .pred_contrib ),
0 commit comments