Skip to content

Commit c884b9e

Browse files
authored
Validate features for inplace predict. (dmlc#8359)
1 parent 52977f0 commit c884b9e

File tree

2 files changed

+76
-50
lines changed

2 files changed

+76
-50
lines changed

python-package/xgboost/core.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,7 @@ def c_array(
406406

407407

408408
def _prediction_output(
409-
shape: CNumericPtr,
410-
dims: c_bst_ulong,
411-
predts: CFloatPtr,
412-
is_cuda: bool
409+
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
413410
) -> NumpyOrCupy:
414411
arr_shape = ctypes2numpy(shape, dims.value, np.uint64)
415412
length = int(np.prod(arr_shape))
@@ -1555,7 +1552,7 @@ def __init__(
15551552
ctypes.byref(self.handle)))
15561553
for d in cache:
15571554
# Validate features only after the feature names are saved into the booster.
1558-
self._validate_features(d)
1555+
self._validate_dmatrix_features(d)
15591556

15601557
if isinstance(model_file, Booster):
15611558
assert self.handle is not None
@@ -1914,7 +1911,7 @@ def update(
19141911
"""
19151912
if not isinstance(dtrain, DMatrix):
19161913
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
1917-
self._validate_features(dtrain)
1914+
self._validate_dmatrix_features(dtrain)
19181915

19191916
if fobj is None:
19201917
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
@@ -1946,7 +1943,7 @@ def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
19461943
)
19471944
if not isinstance(dtrain, DMatrix):
19481945
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
1949-
self._validate_features(dtrain)
1946+
self._validate_dmatrix_features(dtrain)
19501947

19511948
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
19521949
c_array(ctypes.c_float, grad),
@@ -1982,7 +1979,7 @@ def eval_set(
19821979
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
19831980
if not isinstance(d[1], str):
19841981
raise TypeError(f"expected string, got {type(d[1]).__name__}")
1985-
self._validate_features(d[0])
1982+
self._validate_dmatrix_features(d[0])
19861983

19871984
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
19881985
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
@@ -2033,7 +2030,7 @@ def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str:
20332030
result: str
20342031
Evaluation result string.
20352032
"""
2036-
self._validate_features(data)
2033+
self._validate_dmatrix_features(data)
20372034
return self.eval_set([(data, name)], iteration)
20382035

20392036
# pylint: disable=too-many-function-args
@@ -2136,7 +2133,7 @@ def predict(
21362133
if not isinstance(data, DMatrix):
21372134
raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
21382135
if validate_features:
2139-
self._validate_features(data)
2136+
self._validate_dmatrix_features(data)
21402137
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
21412138
args = {
21422139
"type": 0,
@@ -2184,8 +2181,8 @@ def inplace_predict(
21842181
base_margin: Any = None,
21852182
strict_shape: bool = False
21862183
) -> NumpyOrCupy:
2187-
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not
2188-
cache the prediction result.
2184+
"""Run prediction in-place. Unlike the :py:meth:`predict` method, in-place prediction
2185+
does not cache the prediction result.
21892186
21902187
Calling only ``inplace_predict`` in multiple threads is safe and lock
21912188
free. But the safety does not hold when used in conjunction with other
@@ -2273,18 +2270,22 @@ def inplace_predict(
22732270
)
22742271

22752272
from .data import (
2276-
_is_pandas_df,
2277-
_transform_pandas_df,
2273+
_array_interface,
22782274
_is_cudf_df,
22792275
_is_cupy_array,
2280-
_array_interface,
2276+
_is_pandas_df,
2277+
_transform_pandas_df,
22812278
)
2279+
22822280
enable_categorical = _has_categorical(self, data)
22832281
if _is_pandas_df(data):
2284-
data, _, _ = _transform_pandas_df(data, enable_categorical)
2282+
data, fns, _ = _transform_pandas_df(data, enable_categorical)
2283+
if validate_features:
2284+
self._validate_features(fns)
22852285

22862286
if isinstance(data, np.ndarray):
22872287
from .data import _ensure_np_dtype
2288+
22882289
data, _ = _ensure_np_dtype(data, data.dtype)
22892290
_check_call(
22902291
_LIB.XGBoosterPredictFromDense(
@@ -2334,10 +2335,13 @@ def inplace_predict(
23342335
return _prediction_output(shape, dims, preds, True)
23352336
if _is_cudf_df(data):
23362337
from .data import _cudf_array_interfaces, _transform_cudf_df
2337-
data, cat_codes, _, _ = _transform_cudf_df(
2338+
2339+
data, cat_codes, fns, _ = _transform_cudf_df(
23382340
data, None, None, enable_categorical
23392341
)
23402342
interfaces_str = _cudf_array_interfaces(data, cat_codes)
2343+
if validate_features:
2344+
self._validate_features(fns)
23412345
_check_call(
23422346
_LIB.XGBoosterPredictFromCudaColumnar(
23432347
self.handle,
@@ -2723,40 +2727,55 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
27232727
# pylint: disable=no-member
27242728
return df.sort(['Tree', 'Node']).reset_index(drop=True)
27252729

2726-
def _validate_features(self, data: DMatrix) -> None:
2727-
"""
2728-
Validate Booster and data's feature_names are identical.
2729-
Set feature_names and feature_types from DMatrix
2730-
"""
2730+
def _validate_dmatrix_features(self, data: DMatrix) -> None:
27312731
if data.num_row() == 0:
27322732
return
27332733

2734+
fn = data.feature_names
2735+
ft = data.feature_types
2736+
# To stay consistent with versions before 1.7, "validate" still modifies the
2737+
# booster: it fills in missing feature names and types from the DMatrix.
27342738
if self.feature_names is None:
2735-
self.feature_names = data.feature_names
2736-
self.feature_types = data.feature_types
2737-
if data.feature_names is None and self.feature_names is not None:
2739+
self.feature_names = fn
2740+
if self.feature_types is None:
2741+
self.feature_types = ft
2742+
2743+
self._validate_features(fn)
2744+
2745+
def _validate_features(self, feature_names: Optional[FeatureNames]) -> None:
2746+
if self.feature_names is None:
2747+
return
2748+
2749+
if feature_names is None and self.feature_names is not None:
27382750
raise ValueError(
2739-
"training data did not have the following fields: " +
2740-
", ".join(self.feature_names)
2751+
"training data did not have the following fields: "
2752+
+ ", ".join(self.feature_names)
27412753
)
2742-
# Booster can't accept data with different feature names
2743-
if self.feature_names != data.feature_names:
2744-
dat_missing = set(cast(FeatureNames, self.feature_names)) - \
2745-
set(cast(FeatureNames, data.feature_names))
2746-
my_missing = set(cast(FeatureNames, data.feature_names)) - \
2747-
set(cast(FeatureNames, self.feature_names))
27482754

2749-
msg = 'feature_names mismatch: {0} {1}'
2755+
if self.feature_names != feature_names:
2756+
dat_missing = set(cast(FeatureNames, self.feature_names)) - set(
2757+
cast(FeatureNames, feature_names)
2758+
)
2759+
my_missing = set(cast(FeatureNames, feature_names)) - set(
2760+
cast(FeatureNames, self.feature_names)
2761+
)
2762+
2763+
msg = "feature_names mismatch: {0} {1}"
27502764

27512765
if dat_missing:
2752-
msg += ('\nexpected ' + ', '.join(
2753-
str(s) for s in dat_missing) + ' in input data')
2766+
msg += (
2767+
"\nexpected "
2768+
+ ", ".join(str(s) for s in dat_missing)
2769+
+ " in input data"
2770+
)
27542771

27552772
if my_missing:
2756-
msg += ('\ntraining data did not have the following fields: ' +
2757-
', '.join(str(s) for s in my_missing))
2773+
msg += (
2774+
"\ntraining data did not have the following fields: "
2775+
+ ", ".join(str(s) for s in my_missing)
2776+
)
27582777

2759-
raise ValueError(msg.format(self.feature_names, data.feature_names))
2778+
raise ValueError(msg.format(self.feature_names, feature_names))
27602779

27612780
def get_split_value_histogram(
27622781
self,

tests/python/test_with_sklearn.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib.util
33
import json
44
import os
5+
import random
56
import tempfile
67
from typing import Callable, Optional
78

@@ -998,34 +999,40 @@ def test_deprecate_position_arg():
998999
def test_pandas_input():
9991000
import pandas as pd
10001001
from sklearn.calibration import CalibratedClassifierCV
1002+
10011003
rng = np.random.RandomState(1994)
10021004

10031005
kRows = 100
10041006
kCols = 6
10051007

1006-
X = rng.randint(low=0, high=2, size=kRows*kCols)
1008+
X = rng.randint(low=0, high=2, size=kRows * kCols)
10071009
X = X.reshape(kRows, kCols)
10081010

10091011
df = pd.DataFrame(X)
10101012
feature_names = []
10111013
for i in range(1, kCols):
1012-
feature_names += ['k'+str(i)]
1014+
feature_names += ["k" + str(i)]
10131015

1014-
df.columns = ['status'] + feature_names
1016+
df.columns = ["status"] + feature_names
10151017

1016-
target = df['status']
1017-
train = df.drop(columns=['status'])
1018+
target = df["status"]
1019+
train = df.drop(columns=["status"])
10181020
model = xgb.XGBClassifier()
10191021
model.fit(train, target)
10201022
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
10211023

1022-
clf_isotonic = CalibratedClassifierCV(model,
1023-
cv='prefit', method='isotonic')
1024+
columns = list(train.columns)
1025+
random.shuffle(columns, lambda: 0.1)
1026+
df_incorrect = df[columns]
1027+
with pytest.raises(ValueError):
1028+
model.predict(df_incorrect)
1029+
1030+
clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")
10241031
clf_isotonic.fit(train, target)
1025-
assert isinstance(clf_isotonic.calibrated_classifiers_[0].base_estimator,
1026-
xgb.XGBClassifier)
1027-
np.testing.assert_allclose(np.array(clf_isotonic.classes_),
1028-
np.array([0, 1]))
1032+
assert isinstance(
1033+
clf_isotonic.calibrated_classifiers_[0].base_estimator, xgb.XGBClassifier
1034+
)
1035+
np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))
10291036

10301037

10311038
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):

0 commit comments

Comments
 (0)