Skip to content

Commit 4d387cb

Browse files
trivialfis and wbo4958 authored
[backport] [pyspark] rework transform to reuse same code (dmlc#9292) (dmlc#9558)
Co-authored-by: Bobby Wang <[email protected]>
1 parent 3fde936 commit 4d387cb

File tree

1 file changed

+123
-129
lines changed
  • python-package/xgboost/spark

1 file changed

+123
-129
lines changed

python-package/xgboost/spark/core.py

Lines changed: 123 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@
6464
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
6565
from xgboost.training import train as worker_train
6666

67+
from .._typing import ArrayLike
6768
from .data import (
6869
_read_csr_matrix_from_unwrapped_spark_vec,
6970
alias,
@@ -1117,12 +1118,86 @@ def _get_feature_col(
11171118
)
11181119
return features_col, feature_col_names
11191120

1121+
def _get_pred_contrib_col_name(self) -> Optional[str]:
1122+
"""Return the pred_contrib_col col name"""
1123+
pred_contrib_col_name = None
1124+
if (
1125+
self.isDefined(self.pred_contrib_col)
1126+
and self.getOrDefault(self.pred_contrib_col) != ""
1127+
):
1128+
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
1129+
1130+
return pred_contrib_col_name
1131+
1132+
def _out_schema(self) -> Tuple[bool, str]:
1133+
"""Return the bool to indicate if it's a single prediction, true is single prediction,
1134+
and the returned type of the user-defined function. The value must
1135+
be a DDL-formatted type string."""
1136+
1137+
if self._get_pred_contrib_col_name() is not None:
1138+
return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"
1139+
1140+
return True, "double"
1141+
1142+
def _get_predict_func(self) -> Callable:
1143+
"""Return the true prediction function which will be running on the executor side"""
1144+
1145+
predict_params = self._gen_predict_params_dict()
1146+
pred_contrib_col_name = self._get_pred_contrib_col_name()
1147+
1148+
def _predict(
1149+
model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
1150+
) -> Union[pd.DataFrame, pd.Series]:
1151+
data = {}
1152+
preds = model.predict(
1153+
X,
1154+
base_margin=base_margin,
1155+
validate_features=False,
1156+
**predict_params,
1157+
)
1158+
data[pred.prediction] = pd.Series(preds)
1159+
1160+
if pred_contrib_col_name is not None:
1161+
contribs = pred_contribs(model, X, base_margin)
1162+
data[pred.pred_contrib] = pd.Series(list(contribs))
1163+
return pd.DataFrame(data=data)
1164+
1165+
return data[pred.prediction]
1166+
1167+
return _predict
1168+
1169+
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
1170+
"""Post process of transform"""
1171+
prediction_col_name = self.getOrDefault(self.predictionCol)
1172+
single_pred, _ = self._out_schema()
1173+
1174+
if single_pred:
1175+
if prediction_col_name:
1176+
dataset = dataset.withColumn(prediction_col_name, pred_col)
1177+
else:
1178+
pred_struct_col = "_prediction_struct"
1179+
dataset = dataset.withColumn(pred_struct_col, pred_col)
1180+
1181+
if prediction_col_name:
1182+
dataset = dataset.withColumn(
1183+
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
1184+
)
1185+
1186+
pred_contrib_col_name = self._get_pred_contrib_col_name()
1187+
if pred_contrib_col_name is not None:
1188+
dataset = dataset.withColumn(
1189+
pred_contrib_col_name,
1190+
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
1191+
)
1192+
1193+
dataset = dataset.drop(pred_struct_col)
1194+
return dataset
1195+
11201196
def _transform(self, dataset: DataFrame) -> DataFrame:
11211197
# pylint: disable=too-many-statements, too-many-locals
11221198
# Save xgb_sklearn_model and predict_params to be local variable
11231199
# to avoid the `self` object to be pickled to remote.
11241200
xgb_sklearn_model = self._xgb_sklearn_model
1125-
predict_params = self._gen_predict_params_dict()
11261201

11271202
has_base_margin = False
11281203
if (
@@ -1137,18 +1212,9 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
11371212
features_col, feature_col_names = self._get_feature_col(dataset)
11381213
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
11391214

1140-
pred_contrib_col_name = None
1141-
if (
1142-
self.isDefined(self.pred_contrib_col)
1143-
and self.getOrDefault(self.pred_contrib_col) != ""
1144-
):
1145-
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
1215+
predict_func = self._get_predict_func()
11461216

1147-
single_pred = True
1148-
schema = "double"
1149-
if pred_contrib_col_name:
1150-
single_pred = False
1151-
schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
1217+
_, schema = self._out_schema()
11521218

11531219
@pandas_udf(schema) # type: ignore
11541220
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
@@ -1168,48 +1234,14 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
11681234
else:
11691235
base_margin = None
11701236

1171-
data = {}
1172-
preds = model.predict(
1173-
X,
1174-
base_margin=base_margin,
1175-
validate_features=False,
1176-
**predict_params,
1177-
)
1178-
data[pred.prediction] = pd.Series(preds)
1179-
1180-
if pred_contrib_col_name:
1181-
contribs = pred_contribs(model, X, base_margin)
1182-
data[pred.pred_contrib] = pd.Series(list(contribs))
1183-
yield pd.DataFrame(data=data)
1184-
else:
1185-
yield data[pred.prediction]
1237+
yield predict_func(model, X, base_margin)
11861238

11871239
if has_base_margin:
11881240
pred_col = predict_udf(struct(*features_col, base_margin_col))
11891241
else:
11901242
pred_col = predict_udf(struct(*features_col))
11911243

1192-
prediction_col_name = self.getOrDefault(self.predictionCol)
1193-
1194-
if single_pred:
1195-
dataset = dataset.withColumn(prediction_col_name, pred_col)
1196-
else:
1197-
pred_struct_col = "_prediction_struct"
1198-
dataset = dataset.withColumn(pred_struct_col, pred_col)
1199-
1200-
dataset = dataset.withColumn(
1201-
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
1202-
)
1203-
1204-
if pred_contrib_col_name:
1205-
dataset = dataset.withColumn(
1206-
pred_contrib_col_name,
1207-
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
1208-
)
1209-
1210-
dataset = dataset.drop(pred_struct_col)
1211-
1212-
return dataset
1244+
return self._post_transform(dataset, pred_col)
12131245

12141246

12151247
class _ClassificationModel( # pylint: disable=abstract-method
@@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
12211253
.. Note:: This API is experimental.
12221254
"""
12231255

1224-
def _transform(self, dataset: DataFrame) -> DataFrame:
1225-
# pylint: disable=too-many-statements, too-many-locals
1226-
# Save xgb_sklearn_model and predict_params to be local variable
1227-
# to avoid the `self` object to be pickled to remote.
1228-
xgb_sklearn_model = self._xgb_sklearn_model
1229-
predict_params = self._gen_predict_params_dict()
1256+
def _out_schema(self) -> Tuple[bool, str]:
1257+
schema = (
1258+
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
1259+
f" {pred.probability} array<double>"
1260+
)
1261+
if self._get_pred_contrib_col_name() is not None:
1262+
# We will force setting strict_shape to True when predicting contribs,
1263+
# So, it will also output 3-D shape result.
1264+
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
12301265

1231-
has_base_margin = False
1232-
if (
1233-
self.isDefined(self.base_margin_col)
1234-
and self.getOrDefault(self.base_margin_col) != ""
1235-
):
1236-
has_base_margin = True
1237-
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
1238-
alias.margin
1239-
)
1266+
return False, schema
1267+
1268+
def _get_predict_func(self) -> Callable:
1269+
predict_params = self._gen_predict_params_dict()
1270+
pred_contrib_col_name = self._get_pred_contrib_col_name()
12401271

12411272
def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
12421273
if margins.ndim == 1:
@@ -1251,76 +1282,38 @@ def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
12511282
class_probs = softmax(raw_preds, axis=1)
12521283
return raw_preds, class_probs
12531284

1254-
features_col, feature_col_names = self._get_feature_col(dataset)
1255-
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
1256-
1257-
pred_contrib_col_name = None
1258-
if (
1259-
self.isDefined(self.pred_contrib_col)
1260-
and self.getOrDefault(self.pred_contrib_col) != ""
1261-
):
1262-
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
1263-
1264-
schema = (
1265-
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
1266-
f" {pred.probability} array<double>"
1267-
)
1268-
if pred_contrib_col_name:
1269-
# We will force setting strict_shape to True when predicting contribs,
1270-
# So, it will also output 3-D shape result.
1271-
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"
1272-
1273-
@pandas_udf(schema) # type: ignore
1274-
def predict_udf(
1275-
iterator: Iterator[Tuple[pd.Series, ...]]
1276-
) -> Iterator[pd.DataFrame]:
1277-
assert xgb_sklearn_model is not None
1278-
model = xgb_sklearn_model
1279-
for data in iterator:
1280-
if enable_sparse_data_optim:
1281-
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
1282-
else:
1283-
if feature_col_names is not None:
1284-
X = data[feature_col_names] # type: ignore
1285-
else:
1286-
X = stack_series(data[alias.data])
1287-
1288-
if has_base_margin:
1289-
base_margin = stack_series(data[alias.margin])
1290-
else:
1291-
base_margin = None
1292-
1293-
margins = model.predict(
1294-
X,
1295-
base_margin=base_margin,
1296-
output_margin=True,
1297-
validate_features=False,
1298-
**predict_params,
1299-
)
1300-
raw_preds, class_probs = transform_margin(margins)
1301-
1302-
# It seems that they use argmax of class probs,
1303-
# not of margin to get the prediction (Note: scala implementation)
1304-
preds = np.argmax(class_probs, axis=1)
1305-
result: Dict[str, pd.Series] = {
1306-
pred.raw_prediction: pd.Series(list(raw_preds)),
1307-
pred.prediction: pd.Series(preds),
1308-
pred.probability: pd.Series(list(class_probs)),
1309-
}
1285+
def _predict(
1286+
model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
1287+
) -> Union[pd.DataFrame, pd.Series]:
1288+
margins = model.predict(
1289+
X,
1290+
base_margin=base_margin,
1291+
output_margin=True,
1292+
validate_features=False,
1293+
**predict_params,
1294+
)
1295+
raw_preds, class_probs = transform_margin(margins)
1296+
1297+
# It seems that they use argmax of class probs,
1298+
# not of margin to get the prediction (Note: scala implementation)
1299+
preds = np.argmax(class_probs, axis=1)
1300+
result: Dict[str, pd.Series] = {
1301+
pred.raw_prediction: pd.Series(list(raw_preds)),
1302+
pred.prediction: pd.Series(preds),
1303+
pred.probability: pd.Series(list(class_probs)),
1304+
}
13101305

1311-
if pred_contrib_col_name:
1312-
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
1313-
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
1306+
if pred_contrib_col_name is not None:
1307+
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
1308+
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
13141309

1315-
yield pd.DataFrame(data=result)
1310+
return pd.DataFrame(data=result)
13161311

1317-
if has_base_margin:
1318-
pred_struct = predict_udf(struct(*features_col, base_margin_col))
1319-
else:
1320-
pred_struct = predict_udf(struct(*features_col))
1312+
return _predict
13211313

1314+
def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
13221315
pred_struct_col = "_prediction_struct"
1323-
dataset = dataset.withColumn(pred_struct_col, pred_struct)
1316+
dataset = dataset.withColumn(pred_struct_col, pred_col)
13241317

13251318
raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
13261319
if raw_prediction_col_name:
@@ -1342,7 +1335,8 @@ def predict_udf(
13421335
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
13431336
)
13441337

1345-
if pred_contrib_col_name:
1338+
pred_contrib_col_name = self._get_pred_contrib_col_name()
1339+
if pred_contrib_col_name is not None:
13461340
dataset = dataset.withColumn(
13471341
pred_contrib_col_name,
13481342
getattr(col(pred_struct_col), pred.pred_contrib),

0 commit comments

Comments
 (0)