Skip to content

Commit e824b18

Browse files
authored
[backport] Support pandas 2.1.0. (dmlc#9557) (dmlc#9655)
1 parent 66ee89d commit e824b18

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

python-package/xgboost/data.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,6 @@ def pandas_feature_info(
317317
) -> Tuple[Optional[FeatureNames], Optional[FeatureTypes]]:
318318
"""Handle feature info for pandas dataframe."""
319319
import pandas as pd
320-
from pandas.api.types import is_categorical_dtype, is_sparse
321320

322321
# handle feature names
323322
if feature_names is None and meta is None:
@@ -332,10 +331,10 @@ def pandas_feature_info(
332331
if feature_types is None and meta is None:
333332
feature_types = []
334333
for dtype in data.dtypes:
335-
if is_sparse(dtype):
334+
if is_pd_sparse_dtype(dtype):
336335
feature_types.append(_pandas_dtype_mapper[dtype.subtype.name])
337336
elif (
338-
is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
337+
is_pd_cat_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
339338
) and enable_categorical:
340339
feature_types.append(CAT_T)
341340
else:
@@ -345,18 +344,13 @@ def pandas_feature_info(
345344

346345
def is_nullable_dtype(dtype: PandasDType) -> bool:
347346
"""Whether dtype is a pandas nullable type."""
348-
from pandas.api.types import (
349-
is_bool_dtype,
350-
is_categorical_dtype,
351-
is_float_dtype,
352-
is_integer_dtype,
353-
)
347+
from pandas.api.types import is_bool_dtype, is_float_dtype, is_integer_dtype
354348

355349
is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper
356350
# np.bool has alias `bool`, while pd.BooleanDtype has `boolean`.
357351
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
358352
is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper
359-
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
353+
return is_int or is_bool or is_float or is_pd_cat_dtype(dtype)
360354

361355

362356
def is_pa_ext_dtype(dtype: Any) -> bool:
@@ -371,17 +365,48 @@ def is_pa_ext_categorical_dtype(dtype: Any) -> bool:
371365
)
372366

373367

368+
def is_pd_cat_dtype(dtype: PandasDType) -> bool:
369+
"""Wrapper for testing pandas category type."""
370+
import pandas as pd
371+
372+
if hasattr(pd.util, "version") and hasattr(pd.util.version, "Version"):
373+
Version = pd.util.version.Version
374+
if Version(pd.__version__) >= Version("2.1.0"):
375+
from pandas import CategoricalDtype
376+
377+
return isinstance(dtype, CategoricalDtype)
378+
379+
from pandas.api.types import is_categorical_dtype
380+
381+
return is_categorical_dtype(dtype)
382+
383+
384+
def is_pd_sparse_dtype(dtype: PandasDType) -> bool:
385+
"""Wrapper for testing pandas sparse type."""
386+
import pandas as pd
387+
388+
if hasattr(pd.util, "version") and hasattr(pd.util.version, "Version"):
389+
Version = pd.util.version.Version
390+
if Version(pd.__version__) >= Version("2.1.0"):
391+
from pandas import SparseDtype
392+
393+
return isinstance(dtype, SparseDtype)
394+
395+
from pandas.api.types import is_sparse
396+
397+
return is_sparse(dtype)
398+
399+
374400
def pandas_cat_null(data: DataFrame) -> DataFrame:
375401
"""Handle categorical dtype and nullable extension types from pandas."""
376402
import pandas as pd
377-
from pandas.api.types import is_categorical_dtype
378403

379404
# handle category codes and nullable.
380405
cat_columns = []
381406
nul_columns = []
382407
# avoid an unnecessary conversion if possible
383408
for col, dtype in zip(data.columns, data.dtypes):
384-
if is_categorical_dtype(dtype):
409+
if is_pd_cat_dtype(dtype):
385410
cat_columns.append(col)
386411
elif is_pa_ext_categorical_dtype(dtype):
387412
raise ValueError(
@@ -398,7 +423,7 @@ def pandas_cat_null(data: DataFrame) -> DataFrame:
398423
transformed = data
399424

400425
def cat_codes(ser: pd.Series) -> pd.Series:
401-
if is_categorical_dtype(ser.dtype):
426+
if is_pd_cat_dtype(ser.dtype):
402427
return ser.cat.codes
403428
assert is_pa_ext_categorical_dtype(ser.dtype)
404429
# Not yet supported, the index is not ordered for some reason. Alternately:
@@ -454,14 +479,12 @@ def _transform_pandas_df(
454479
meta: Optional[str] = None,
455480
meta_type: Optional[NumpyDType] = None,
456481
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
457-
from pandas.api.types import is_categorical_dtype, is_sparse
458-
459482
pyarrow_extension = False
460483
for dtype in data.dtypes:
461484
if not (
462485
(dtype.name in _pandas_dtype_mapper)
463-
or is_sparse(dtype)
464-
or (is_categorical_dtype(dtype) and enable_categorical)
486+
or is_pd_sparse_dtype(dtype)
487+
or (is_pd_cat_dtype(dtype) and enable_categorical)
465488
or is_pa_ext_dtype(dtype)
466489
):
467490
_invalid_dataframe_dtype(data)
@@ -515,9 +538,8 @@ def _meta_from_pandas_series(
515538
) -> None:
516539
"""Help transform pandas series for meta data like labels"""
517540
data = data.values.astype("float")
518-
from pandas.api.types import is_sparse
519541

520-
if is_sparse(data):
542+
if is_pd_sparse_dtype(getattr(data, "dtype", data)):
521543
data = data.to_dense() # type: ignore
522544
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
523545
_meta_from_numpy(data, name, dtype, handle)
@@ -539,13 +561,11 @@ def _from_pandas_series(
539561
feature_names: Optional[FeatureNames],
540562
feature_types: Optional[FeatureTypes],
541563
) -> DispatchedDataBackendReturnType:
542-
from pandas.api.types import is_categorical_dtype
543-
544564
if (data.dtype.name not in _pandas_dtype_mapper) and not (
545-
is_categorical_dtype(data.dtype) and enable_categorical
565+
is_pd_cat_dtype(data.dtype) and enable_categorical
546566
):
547567
_invalid_dataframe_dtype(data)
548-
if enable_categorical and is_categorical_dtype(data.dtype):
568+
if enable_categorical and is_pd_cat_dtype(data.dtype):
549569
data = data.cat.codes
550570
return _from_numpy_array(
551571
data.values.reshape(data.shape[0], 1).astype("float"),

tests/python/test_with_pandas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_pandas_weight(self):
211211
y = np.random.randn(kRows)
212212
w = np.random.uniform(size=kRows).astype(np.float32)
213213
w_pd = pd.DataFrame(w)
214-
data = xgb.DMatrix(X, y, w_pd)
214+
data = xgb.DMatrix(X, y, weight=w_pd)
215215

216216
assert data.num_row() == kRows
217217
assert data.num_col() == kCols
@@ -301,14 +301,14 @@ def test_cv_as_pandas(self):
301301

302302
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
303303
def test_nullable_type(self, DMatrixT) -> None:
304-
from pandas.api.types import is_categorical_dtype
304+
from xgboost.data import is_pd_cat_dtype
305305

306306
for orig, df in pd_dtypes():
307307
if hasattr(df.dtypes, "__iter__"):
308-
enable_categorical = any(is_categorical_dtype for dtype in df.dtypes)
308+
enable_categorical = any(is_pd_cat_dtype(dtype) for dtype in df.dtypes)
309309
else:
310310
# series
311-
enable_categorical = is_categorical_dtype(df.dtype)
311+
enable_categorical = is_pd_cat_dtype(df.dtype)
312312

313313
f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig
314314
f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df

0 commit comments

Comments
 (0)