@@ -406,10 +406,7 @@ def c_array(
406
406
407
407
408
408
def _prediction_output (
409
- shape : CNumericPtr ,
410
- dims : c_bst_ulong ,
411
- predts : CFloatPtr ,
412
- is_cuda : bool
409
+ shape : CNumericPtr , dims : c_bst_ulong , predts : CFloatPtr , is_cuda : bool
413
410
) -> NumpyOrCupy :
414
411
arr_shape = ctypes2numpy (shape , dims .value , np .uint64 )
415
412
length = int (np .prod (arr_shape ))
@@ -1555,7 +1552,7 @@ def __init__(
1555
1552
ctypes .byref (self .handle )))
1556
1553
for d in cache :
1557
1554
# Validate feature only after the feature names are saved into booster.
1558
- self ._validate_features (d )
1555
+ self ._validate_dmatrix_features (d )
1559
1556
1560
1557
if isinstance (model_file , Booster ):
1561
1558
assert self .handle is not None
@@ -1914,7 +1911,7 @@ def update(
1914
1911
"""
1915
1912
if not isinstance (dtrain , DMatrix ):
1916
1913
raise TypeError (f"invalid training matrix: { type (dtrain ).__name__ } " )
1917
- self ._validate_features (dtrain )
1914
+ self ._validate_dmatrix_features (dtrain )
1918
1915
1919
1916
if fobj is None :
1920
1917
_check_call (_LIB .XGBoosterUpdateOneIter (self .handle ,
@@ -1946,7 +1943,7 @@ def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
1946
1943
)
1947
1944
if not isinstance (dtrain , DMatrix ):
1948
1945
raise TypeError (f"invalid training matrix: { type (dtrain ).__name__ } " )
1949
- self ._validate_features (dtrain )
1946
+ self ._validate_dmatrix_features (dtrain )
1950
1947
1951
1948
_check_call (_LIB .XGBoosterBoostOneIter (self .handle , dtrain .handle ,
1952
1949
c_array (ctypes .c_float , grad ),
@@ -1982,7 +1979,7 @@ def eval_set(
1982
1979
raise TypeError (f"expected DMatrix, got { type (d [0 ]).__name__ } " )
1983
1980
if not isinstance (d [1 ], str ):
1984
1981
raise TypeError (f"expected string, got { type (d [1 ]).__name__ } " )
1985
- self ._validate_features (d [0 ])
1982
+ self ._validate_dmatrix_features (d [0 ])
1986
1983
1987
1984
dmats = c_array (ctypes .c_void_p , [d [0 ].handle for d in evals ])
1988
1985
evnames = c_array (ctypes .c_char_p , [c_str (d [1 ]) for d in evals ])
@@ -2033,7 +2030,7 @@ def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str:
2033
2030
result: str
2034
2031
Evaluation result string.
2035
2032
"""
2036
- self ._validate_features (data )
2033
+ self ._validate_dmatrix_features (data )
2037
2034
return self .eval_set ([(data , name )], iteration )
2038
2035
2039
2036
# pylint: disable=too-many-function-args
@@ -2136,7 +2133,7 @@ def predict(
2136
2133
if not isinstance (data , DMatrix ):
2137
2134
raise TypeError ('Expecting data to be a DMatrix object, got: ' , type (data ))
2138
2135
if validate_features :
2139
- self ._validate_features (data )
2136
+ self ._validate_dmatrix_features (data )
2140
2137
iteration_range = _convert_ntree_limit (self , ntree_limit , iteration_range )
2141
2138
args = {
2142
2139
"type" : 0 ,
@@ -2184,8 +2181,8 @@ def inplace_predict(
2184
2181
base_margin : Any = None ,
2185
2182
strict_shape : bool = False
2186
2183
) -> NumpyOrCupy :
2187
- """Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not
2188
- cache the prediction result.
2184
+ """Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction
2185
+ does not cache the prediction result.
2189
2186
2190
2187
Calling only ``inplace_predict`` in multiple threads is safe and lock
2191
2188
free. But the safety does not hold when used in conjunction with other
@@ -2273,18 +2270,22 @@ def inplace_predict(
2273
2270
)
2274
2271
2275
2272
from .data import (
2276
- _is_pandas_df ,
2277
- _transform_pandas_df ,
2273
+ _array_interface ,
2278
2274
_is_cudf_df ,
2279
2275
_is_cupy_array ,
2280
- _array_interface ,
2276
+ _is_pandas_df ,
2277
+ _transform_pandas_df ,
2281
2278
)
2279
+
2282
2280
enable_categorical = _has_categorical (self , data )
2283
2281
if _is_pandas_df (data ):
2284
- data , _ , _ = _transform_pandas_df (data , enable_categorical )
2282
+ data , fns , _ = _transform_pandas_df (data , enable_categorical )
2283
+ if validate_features :
2284
+ self ._validate_features (fns )
2285
2285
2286
2286
if isinstance (data , np .ndarray ):
2287
2287
from .data import _ensure_np_dtype
2288
+
2288
2289
data , _ = _ensure_np_dtype (data , data .dtype )
2289
2290
_check_call (
2290
2291
_LIB .XGBoosterPredictFromDense (
@@ -2334,10 +2335,13 @@ def inplace_predict(
2334
2335
return _prediction_output (shape , dims , preds , True )
2335
2336
if _is_cudf_df (data ):
2336
2337
from .data import _cudf_array_interfaces , _transform_cudf_df
2337
- data , cat_codes , _ , _ = _transform_cudf_df (
2338
+
2339
+ data , cat_codes , fns , _ = _transform_cudf_df (
2338
2340
data , None , None , enable_categorical
2339
2341
)
2340
2342
interfaces_str = _cudf_array_interfaces (data , cat_codes )
2343
+ if validate_features :
2344
+ self ._validate_features (fns )
2341
2345
_check_call (
2342
2346
_LIB .XGBoosterPredictFromCudaColumnar (
2343
2347
self .handle ,
@@ -2723,40 +2727,55 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
2723
2727
# pylint: disable=no-member
2724
2728
return df .sort (['Tree' , 'Node' ]).reset_index (drop = True )
2725
2729
2726
- def _validate_features (self , data : DMatrix ) -> None :
2727
- """
2728
- Validate Booster and data's feature_names are identical.
2729
- Set feature_names and feature_types from DMatrix
2730
- """
2730
+ def _validate_dmatrix_features (self , data : DMatrix ) -> None :
2731
2731
if data .num_row () == 0 :
2732
2732
return
2733
2733
2734
+ fn = data .feature_names
2735
+ ft = data .feature_types
2736
+ # Be consistent with versions before 1.7, "validate" actually modifies the
2737
+ # booster.
2734
2738
if self .feature_names is None :
2735
- self .feature_names = data .feature_names
2736
- self .feature_types = data .feature_types
2737
- if data .feature_names is None and self .feature_names is not None :
2739
+ self .feature_names = fn
2740
+ if self .feature_types is None :
2741
+ self .feature_types = ft
2742
+
2743
+ self ._validate_features (fn )
2744
+
2745
+ def _validate_features (self , feature_names : Optional [FeatureNames ]) -> None :
2746
+ if self .feature_names is None :
2747
+ return
2748
+
2749
+ if feature_names is None and self .feature_names is not None :
2738
2750
raise ValueError (
2739
- "training data did not have the following fields: " +
2740
- ", " .join (self .feature_names )
2751
+ "training data did not have the following fields: "
2752
+ + ", " .join (self .feature_names )
2741
2753
)
2742
- # Booster can't accept data with different feature names
2743
- if self .feature_names != data .feature_names :
2744
- dat_missing = set (cast (FeatureNames , self .feature_names )) - \
2745
- set (cast (FeatureNames , data .feature_names ))
2746
- my_missing = set (cast (FeatureNames , data .feature_names )) - \
2747
- set (cast (FeatureNames , self .feature_names ))
2748
2754
2749
- msg = 'feature_names mismatch: {0} {1}'
2755
+ if self .feature_names != feature_names :
2756
+ dat_missing = set (cast (FeatureNames , self .feature_names )) - set (
2757
+ cast (FeatureNames , feature_names )
2758
+ )
2759
+ my_missing = set (cast (FeatureNames , feature_names )) - set (
2760
+ cast (FeatureNames , self .feature_names )
2761
+ )
2762
+
2763
+ msg = "feature_names mismatch: {0} {1}"
2750
2764
2751
2765
if dat_missing :
2752
- msg += ('\n expected ' + ', ' .join (
2753
- str (s ) for s in dat_missing ) + ' in input data' )
2766
+ msg += (
2767
+ "\n expected "
2768
+ + ", " .join (str (s ) for s in dat_missing )
2769
+ + " in input data"
2770
+ )
2754
2771
2755
2772
if my_missing :
2756
- msg += ('\n training data did not have the following fields: ' +
2757
- ', ' .join (str (s ) for s in my_missing ))
2773
+ msg += (
2774
+ "\n training data did not have the following fields: "
2775
+ + ", " .join (str (s ) for s in my_missing )
2776
+ )
2758
2777
2759
- raise ValueError (msg .format (self .feature_names , data . feature_names ))
2778
+ raise ValueError (msg .format (self .feature_names , feature_names ))
2760
2779
2761
2780
def get_split_value_histogram (
2762
2781
self ,
0 commit comments