Skip to content

Commit d42575f

Browse files
authored
BUG: read_csv with engine=pyarrow and numpy-nullable dtype (#62053)
1 parent bb10b27 commit d42575f

File tree

5 files changed

+124
-22
lines changed

5 files changed

+124
-22
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ I/O
857857
- Bug in :meth:`read_csv` raising ``TypeError`` when ``index_col`` is specified and ``na_values`` is a dict containing the key ``None``. (:issue:`57547`)
858858
- Bug in :meth:`read_csv` raising ``TypeError`` when ``nrows`` and ``iterator`` are specified without specifying a ``chunksize``. (:issue:`59079`)
859859
- Bug in :meth:`read_csv` where the order of the ``na_values`` causes an inconsistency when ``na_values`` is a list of non-string values. (:issue:`59303`)
860+
- Bug in :meth:`read_csv` with ``engine="pyarrow"`` and ``dtype="Int64"`` losing precision (:issue:`56136`)
860861
- Bug in :meth:`read_excel` raising ``ValueError`` when passing array of boolean values when ``dtype="boolean"``. (:issue:`58159`)
861862
- Bug in :meth:`read_html` where ``rowspan`` in header row causes incorrect conversion to ``DataFrame``. (:issue:`60210`)
862863
- Bug in :meth:`read_json` ignoring the given ``dtype`` when ``engine="pyarrow"`` (:issue:`59516`)

pandas/io/_util.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,23 @@
1616
)
1717
from pandas.compat._optional import import_optional_dependency
1818

19+
from pandas.core.dtypes.common import pandas_dtype
20+
1921
import pandas as pd
2022

2123
if TYPE_CHECKING:
22-
from collections.abc import Callable
24+
from collections.abc import (
25+
Callable,
26+
Hashable,
27+
Sequence,
28+
)
2329

2430
import pyarrow
2531

26-
from pandas._typing import DtypeBackend
32+
from pandas._typing import (
33+
DtypeArg,
34+
DtypeBackend,
35+
)
2736

2837

2938
def _arrow_dtype_mapping() -> dict:
@@ -64,6 +73,8 @@ def arrow_table_to_pandas(
6473
dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
6574
null_to_int64: bool = False,
6675
to_pandas_kwargs: dict | None = None,
76+
dtype: DtypeArg | None = None,
77+
names: Sequence[Hashable] | None = None,
6778
) -> pd.DataFrame:
6879
pa = import_optional_dependency("pyarrow")
6980

@@ -82,12 +93,77 @@ def arrow_table_to_pandas(
8293
elif using_string_dtype():
8394
if pa_version_under19p0:
8495
types_mapper = _arrow_string_types_mapper()
96+
elif dtype is not None:
97+
# GH#56136 Avoid lossy conversion to float64
98+
# We'll convert to numpy below if needed
99+
types_mapper = {
100+
pa.int8(): pd.Int8Dtype(),
101+
pa.int16(): pd.Int16Dtype(),
102+
pa.int32(): pd.Int32Dtype(),
103+
pa.int64(): pd.Int64Dtype(),
104+
}.get
85105
else:
86106
types_mapper = None
87107
elif dtype_backend is lib.no_default or dtype_backend == "numpy":
88-
types_mapper = None
108+
if dtype is not None:
109+
# GH#56136 Avoid lossy conversion to float64
110+
# We'll convert to numpy below if needed
111+
types_mapper = {
112+
pa.int8(): pd.Int8Dtype(),
113+
pa.int16(): pd.Int16Dtype(),
114+
pa.int32(): pd.Int32Dtype(),
115+
pa.int64(): pd.Int64Dtype(),
116+
}.get
117+
else:
118+
types_mapper = None
89119
else:
90120
raise NotImplementedError
91121

92122
df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs)
123+
return _post_convert_dtypes(df, dtype_backend, dtype, names)
124+
125+
126+
def _post_convert_dtypes(
127+
df: pd.DataFrame,
128+
dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault,
129+
dtype: DtypeArg | None,
130+
names: Sequence[Hashable] | None,
131+
) -> pd.DataFrame:
132+
if dtype is not None and (
133+
dtype_backend is lib.no_default or dtype_backend == "numpy"
134+
):
135+
# GH#56136 apply any user-provided dtype, and convert any IntegerDtype
136+
# columns the user didn't explicitly ask for.
137+
if isinstance(dtype, dict):
138+
if names is not None:
139+
df.columns = names
140+
141+
cmp_dtypes = {
142+
pd.Int8Dtype(),
143+
pd.Int16Dtype(),
144+
pd.Int32Dtype(),
145+
pd.Int64Dtype(),
146+
}
147+
for col in df.columns:
148+
if col not in dtype and df[col].dtype in cmp_dtypes:
149+
# Any key that the user didn't explicitly specify
150+
# that got converted to IntegerDtype now gets converted
151+
# to numpy dtype.
152+
dtype[col] = df[col].dtype.numpy_dtype
153+
154+
# Ignore non-existent columns from dtype mapping
155+
# like other parsers do
156+
dtype = {
157+
key: pandas_dtype(dtype[key]) for key in dtype if key in df.columns
158+
}
159+
160+
else:
161+
dtype = pandas_dtype(dtype)
162+
163+
try:
164+
df = df.astype(dtype)
165+
except TypeError as err:
166+
# GH#44901 reraise to keep api consistent
167+
raise ValueError(str(err)) from err
168+
93169
return df

pandas/io/parsers/arrow_parser_wrapper.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
)
1212
from pandas.util._exceptions import find_stack_level
1313

14-
from pandas.core.dtypes.common import pandas_dtype
14+
from pandas.core.dtypes.common import (
15+
pandas_dtype,
16+
)
1517
from pandas.core.dtypes.inference import is_integer
1618

1719
from pandas.io._util import arrow_table_to_pandas
1820
from pandas.io.parsers.base_parser import ParserBase
1921

2022
if TYPE_CHECKING:
23+
import pyarrow as pa
24+
2125
from pandas._typing import ReadBuffer
2226

2327
from pandas import DataFrame
@@ -162,13 +166,12 @@ def _get_convert_options(self):
162166

163167
return convert_options
164168

165-
def _adjust_column_names(self, frame: DataFrame) -> tuple[DataFrame, bool]:
166-
num_cols = len(frame.columns)
169+
def _adjust_column_names(self, table: pa.Table) -> bool:
170+
num_cols = len(table.columns)
167171
multi_index_named = True
168172
if self.header is None:
169173
if self.names is None:
170-
if self.header is None:
171-
self.names = range(num_cols)
174+
self.names = range(num_cols)
172175
if len(self.names) != num_cols:
173176
# usecols is passed through to pyarrow, we only handle index col here
174177
# The only way self.names is not the same length as number of cols is
@@ -177,8 +180,7 @@ def _adjust_column_names(self, frame: DataFrame) -> tuple[DataFrame, bool]:
177180
columns_prefix = [str(x) for x in range(num_cols - len(self.names))]
178181
self.names = columns_prefix + self.names
179182
multi_index_named = False
180-
frame.columns = self.names
181-
return frame, multi_index_named
183+
return multi_index_named
182184

183185
def _finalize_index(self, frame: DataFrame, multi_index_named: bool) -> DataFrame:
184186
if self.index_col is not None:
@@ -227,21 +229,23 @@ def _finalize_dtype(self, frame: DataFrame) -> DataFrame:
227229
raise ValueError(str(err)) from err
228230
return frame
229231

230-
def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
232+
def _finalize_pandas_output(
233+
self, frame: DataFrame, multi_index_named: bool
234+
) -> DataFrame:
231235
"""
232236
Processes data read in based on kwargs.
233237
234238
Parameters
235239
----------
236-
frame: DataFrame
240+
frame : DataFrame
237241
The DataFrame to process.
242+
multi_index_named : bool
238243
239244
Returns
240245
-------
241246
DataFrame
242247
The processed DataFrame.
243248
"""
244-
frame, multi_index_named = self._adjust_column_names(frame)
245249
frame = self._do_date_conversions(frame.columns, frame)
246250
frame = self._finalize_index(frame, multi_index_named)
247251
frame = self._finalize_dtype(frame)
@@ -299,14 +303,23 @@ def read(self) -> DataFrame:
299303

300304
table = table.cast(new_schema)
301305

306+
multi_index_named = self._adjust_column_names(table)
307+
302308
with warnings.catch_warnings():
303309
warnings.filterwarnings(
304310
"ignore",
305311
"make_block is deprecated",
306312
DeprecationWarning,
307313
)
308314
frame = arrow_table_to_pandas(
309-
table, dtype_backend=dtype_backend, null_to_int64=True
315+
table,
316+
dtype_backend=dtype_backend,
317+
null_to_int64=True,
318+
dtype=self.dtype,
319+
names=self.names,
310320
)
311321

312-
return self._finalize_pandas_output(frame)
322+
if self.header is None:
323+
frame.columns = self.names
324+
325+
return self._finalize_pandas_output(frame, multi_index_named)

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,6 @@ def test_dtype_backend_pyarrow(all_parsers, request):
518518
tm.assert_frame_equal(result, expected)
519519

520520

521-
# pyarrow engine failing:
522-
# https://github.com/pandas-dev/pandas/issues/56136
523-
@pytest.mark.usefixtures("pyarrow_xfail")
524521
def test_ea_int_avoid_overflow(all_parsers):
525522
# GH#32134
526523
parser = all_parsers
@@ -594,7 +591,6 @@ def test_string_inference_object_dtype(all_parsers, dtype, using_infer_string):
594591
tm.assert_frame_equal(result, expected)
595592

596593

597-
@xfail_pyarrow
598594
def test_accurate_parsing_of_large_integers(all_parsers):
599595
# GH#52505
600596
data = """SYMBOL,MOMENT,ID,ID_DEAL

pandas/tests/io/parser/test_na_values.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,16 @@ def test_inf_na_values_with_int_index(all_parsers):
670670
tm.assert_frame_equal(out, expected)
671671

672672

673-
@xfail_pyarrow # mismatched shape
674673
@pytest.mark.parametrize("na_filter", [True, False])
675-
def test_na_values_with_dtype_str_and_na_filter(all_parsers, na_filter):
674+
def test_na_values_with_dtype_str_and_na_filter(
675+
all_parsers, na_filter, using_infer_string, request
676+
):
676677
# see gh-20377
677678
parser = all_parsers
679+
if parser.engine == "pyarrow" and (na_filter is False or not using_infer_string):
680+
mark = pytest.mark.xfail(reason="mismatched shape")
681+
request.applymarker(mark)
682+
678683
data = "a,b,c\n1,,3\n4,5,6"
679684

680685
# na_filter=True --> missing value becomes NaN.
@@ -798,7 +803,18 @@ def test_bool_and_nan_to_int(all_parsers):
798803
True
799804
False
800805
"""
801-
with pytest.raises(ValueError, match="convert|NoneType"):
806+
msg = (
807+
"cannot safely convert passed user dtype of int(64|32) for "
808+
"<class 'numpy.bool_?'> dtyped data in column 0 due to NA values"
809+
)
810+
if parser.engine == "python":
811+
msg = "Unable to convert column 0 to type int(64|32)"
812+
elif parser.engine == "pyarrow":
813+
msg = (
814+
r"int\(\) argument must be a string, a bytes-like object or a "
815+
"real number, not 'NoneType"
816+
)
817+
with pytest.raises(ValueError, match=msg):
802818
parser.read_csv(StringIO(data), dtype="int")
803819

804820

0 commit comments

Comments
 (0)