Skip to content

Commit a4c6cde

Browse files
authored
[BP] Fix boolean array for arrow-backed DF. (dmlc#10527) (dmlc#10901)
1 parent 7f87b9e commit a4c6cde

File tree

2 files changed

+21
-44
lines changed

2 files changed

+21
-44
lines changed

python-package/xgboost/data.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def pandas_pa_type(ser: Any) -> np.ndarray:
458458
# combine_chunks takes the most significant amount of time
459459
chunk: pa.Array = aa.combine_chunks()
460460
# When there are null values, we have to make a copy
461-
zero_copy = chunk.null_count == 0
461+
zero_copy = chunk.null_count == 0 and not pa.types.is_boolean(chunk.type)
462462
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
463463
# we need to concatenate them ourselves.
464464
# FIXME(jiamingy): Is there a better way to access the arrow buffer along with
@@ -825,37 +825,9 @@ def _arrow_transform(data: DataType) -> Any:
825825

826826
data = cast(pa.Table, data)
827827

828-
def type_mapper(dtype: pa.DataType) -> Optional[str]:
829-
"""Maps pyarrow type to pandas arrow extension type."""
830-
if pa.types.is_int8(dtype):
831-
return pd.ArrowDtype(pa.int8())
832-
if pa.types.is_int16(dtype):
833-
return pd.ArrowDtype(pa.int16())
834-
if pa.types.is_int32(dtype):
835-
return pd.ArrowDtype(pa.int32())
836-
if pa.types.is_int64(dtype):
837-
return pd.ArrowDtype(pa.int64())
838-
if pa.types.is_uint8(dtype):
839-
return pd.ArrowDtype(pa.uint8())
840-
if pa.types.is_uint16(dtype):
841-
return pd.ArrowDtype(pa.uint16())
842-
if pa.types.is_uint32(dtype):
843-
return pd.ArrowDtype(pa.uint32())
844-
if pa.types.is_uint64(dtype):
845-
return pd.ArrowDtype(pa.uint64())
846-
if pa.types.is_float16(dtype):
847-
return pd.ArrowDtype(pa.float16())
848-
if pa.types.is_float32(dtype):
849-
return pd.ArrowDtype(pa.float32())
850-
if pa.types.is_float64(dtype):
851-
return pd.ArrowDtype(pa.float64())
852-
if pa.types.is_boolean(dtype):
853-
return pd.ArrowDtype(pa.bool_())
854-
return None
855-
856828
# For common cases, this is zero-copy, can check with:
857829
# pa.total_allocated_bytes()
858-
df = data.to_pandas(types_mapper=type_mapper)
830+
df = data.to_pandas(types_mapper=pd.ArrowDtype)
859831
return df
860832

861833

python-package/xgboost/testing/data.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,6 @@ def pd_arrow_dtypes() -> Generator:
165165

166166
# Integer
167167
dtypes = pandas_pyarrow_mapper
168-
Null: Union[float, None, Any] = np.nan
169-
orig = pd.DataFrame(
170-
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=np.float32
171-
)
172168
# Create a dictionary-backed dataframe, enable this when the roundtrip is
173169
# implemented in pandas/pyarrow
174170
#
@@ -191,24 +187,33 @@ def pd_arrow_dtypes() -> Generator:
191187
# pd_catcodes = pd_cat_df["f1"].cat.codes
192188
# assert pd_catcodes.equals(pa_catcodes)
193189

194-
for Null in (None, pd.NA):
190+
for Null in (None, pd.NA, 0):
195191
for dtype in dtypes:
196192
if dtype.startswith("float16") or dtype.startswith("bool"):
197193
continue
194+
# Use np.nan as a baseline
195+
orig_null = Null if not pd.isna(Null) and Null == 0 else np.nan
196+
orig = pd.DataFrame(
197+
{"f0": [1, 2, orig_null, 3], "f1": [4, 3, orig_null, 1]},
198+
dtype=np.float32,
199+
)
200+
198201
df = pd.DataFrame(
199202
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=dtype
200203
)
201204
yield orig, df
202205

203-
orig = pd.DataFrame(
204-
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
205-
dtype=pd.BooleanDtype(),
206-
)
207-
df = pd.DataFrame(
208-
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
209-
dtype=pd.ArrowDtype(pa.bool_()),
210-
)
211-
yield orig, df
206+
# If Null is `False`, then there's no missing value.
207+
for Null in (pd.NA, False):
208+
orig = pd.DataFrame(
209+
{"f0": [True, False, Null, True], "f1": [False, True, Null, True]},
210+
dtype=pd.BooleanDtype(),
211+
)
212+
df = pd.DataFrame(
213+
{"f0": [True, False, Null, True], "f1": [False, True, Null, True]},
214+
dtype=pd.ArrowDtype(pa.bool_()),
215+
)
216+
yield orig, df
212217

213218

214219
def check_inf(rng: RNG) -> None:

0 commit comments

Comments
 (0)