Skip to content

Commit 98bedc4

Browse files
committed
BUG: read_csv with engine=pyarrow and numpy-nullable dtype
1 parent eb489f2 commit 98bedc4

File tree

4 files changed

+74
-22
lines changed

4 files changed

+74
-22
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ I/O
814814
- Bug in :meth:`read_csv` raising ``TypeError`` when ``index_col`` is specified and ``na_values`` is a dict containing the key ``None``. (:issue:`57547`)
815815
- Bug in :meth:`read_csv` raising ``TypeError`` when ``nrows`` and ``iterator`` are specified without specifying a ``chunksize``. (:issue:`59079`)
816816
- Bug in :meth:`read_csv` where the order of the ``na_values`` causes an inconsistency when ``na_values`` is a list of non-string values. (:issue:`59303`)
817+
- Bug in :meth:`read_csv` with ``engine="pyarrow"`` and ``dtype="Int64"`` losing precision (:issue:`56136`)
817818
- Bug in :meth:`read_excel` raising ``ValueError`` when passing array of boolean values when ``dtype="boolean"``. (:issue:`58159`)
818819
- Bug in :meth:`read_html` where ``rowspan`` in header row causes incorrect conversion to ``DataFrame``. (:issue:`60210`)
819820
- Bug in :meth:`read_json` ignoring the given ``dtype`` when ``engine="pyarrow"`` (:issue:`59516`)

pandas/io/parsers/arrow_parser_wrapper.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import TYPE_CHECKING
44
import warnings
55

6+
import numpy as np
7+
68
from pandas._libs import lib
79
from pandas.compat._optional import import_optional_dependency
810
from pandas.errors import (
@@ -12,8 +14,13 @@
1214
from pandas.util._exceptions import find_stack_level
1315

1416
from pandas.core.dtypes.common import pandas_dtype
17+
from pandas.core.dtypes.dtypes import (
18+
BaseMaskedDtype,
19+
)
1520
from pandas.core.dtypes.inference import is_integer
1621

22+
from pandas.core.arrays.string_ import StringDtype
23+
1724
from pandas.io._util import arrow_table_to_pandas
1825
from pandas.io.parsers.base_parser import ParserBase
1926

@@ -140,20 +147,7 @@ def handle_warning(invalid_row) -> str:
140147
"encoding": self.encoding,
141148
}
142149

143-
def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
144-
"""
145-
Processes data read in based on kwargs.
146-
147-
Parameters
148-
----------
149-
frame: DataFrame
150-
The DataFrame to process.
151-
152-
Returns
153-
-------
154-
DataFrame
155-
The processed DataFrame.
156-
"""
150+
def _finalize_column_names(self, frame: DataFrame) -> DataFrame:
157151
num_cols = len(frame.columns)
158152
multi_index_named = True
159153
if self.header is None:
@@ -196,6 +190,23 @@ def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
196190
if self.header is None and not multi_index_named:
197191
frame.index.names = [None] * len(frame.index.names)
198192

193+
return frame
194+
195+
def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
196+
"""
197+
Processes data read in based on kwargs.
198+
199+
Parameters
200+
----------
201+
frame: DataFrame
202+
The DataFrame to process.
203+
204+
Returns
205+
-------
206+
DataFrame
207+
The processed DataFrame.
208+
"""
209+
199210
if self.dtype is not None:
200211
# Ignore non-existent columns from dtype mapping
201212
# like other parsers do
@@ -282,14 +293,47 @@ def read(self) -> DataFrame:
282293

283294
table = table.cast(new_schema)
284295

296+
workaround = False
297+
pass_backend = dtype_backend
298+
if self.dtype is not None and dtype_backend != "pyarrow":
299+
# We pass dtype_backend="pyarrow" and subsequently cast
300+
# to avoid lossy conversion e.g. GH#56136
301+
workaround = True
302+
pass_backend = "numpy_nullable"
303+
285304
with warnings.catch_warnings():
286305
warnings.filterwarnings(
287306
"ignore",
288307
"make_block is deprecated",
289308
DeprecationWarning,
290309
)
291310
frame = arrow_table_to_pandas(
292-
table, dtype_backend=dtype_backend, null_to_int64=True
311+
table, dtype_backend=pass_backend, null_to_int64=True
293312
)
294313

314+
frame = self._finalize_column_names(frame)
315+
316+
if workaround and dtype_backend != "numpy_nullable":
317+
old_dtype = self.dtype
318+
if not isinstance(old_dtype, dict):
319+
# e.g. test_categorical_dtype_utf16
320+
old_dtype = dict.fromkeys(frame.columns, old_dtype)
321+
322+
# _finalize_pandas_output will call astype, but we need to make
323+
# sure all keys are populated appropriately.
324+
new_dtype = {}
325+
for key in frame.columns:
326+
ser = frame[key]
327+
if isinstance(ser.dtype, BaseMaskedDtype):
328+
new_dtype[key] = ser.dtype.numpy_dtype
329+
elif isinstance(ser.dtype, StringDtype):
330+
# We cast here in case the user passed "category" in
331+
# order to get the correct dtype.categories.dtype
332+
# e.g. test_categorical_dtype_utf16
333+
new_dtype[key] = StringDtype(na_value=np.nan)
334+
frame[key] = frame[key].astype(new_dtype[key])
335+
336+
new_dtype.update(old_dtype)
337+
self.dtype = new_dtype
338+
295339
return self._finalize_pandas_output(frame)

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,6 @@ def test_dtype_backend_pyarrow(all_parsers, request):
518518
tm.assert_frame_equal(result, expected)
519519

520520

521-
# pyarrow engine failing:
522-
# https://github.com/pandas-dev/pandas/issues/56136
523-
@pytest.mark.usefixtures("pyarrow_xfail")
524521
def test_ea_int_avoid_overflow(all_parsers):
525522
# GH#32134
526523
parser = all_parsers
@@ -594,7 +591,6 @@ def test_string_inference_object_dtype(all_parsers, dtype, using_infer_string):
594591
tm.assert_frame_equal(result, expected)
595592

596593

597-
@xfail_pyarrow
598594
def test_accurate_parsing_of_large_integers(all_parsers):
599595
# GH#52505
600596
data = """SYMBOL,MOMENT,ID,ID_DEAL

pandas/tests/io/parser/test_na_values.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,14 @@ def test_inf_na_values_with_int_index(all_parsers):
670670
tm.assert_frame_equal(out, expected)
671671

672672

673-
@xfail_pyarrow # mismatched shape
674673
@pytest.mark.parametrize("na_filter", [True, False])
675-
def test_na_values_with_dtype_str_and_na_filter(all_parsers, na_filter):
674+
def test_na_values_with_dtype_str_and_na_filter(all_parsers, na_filter, request):
676675
# see gh-20377
677676
parser = all_parsers
677+
if parser.engine == "pyarrow" and na_filter is False:
678+
mark = pytest.mark.xfail(reason="mismatched shape")
679+
request.applymarker(mark)
680+
678681
data = "a,b,c\n1,,3\n4,5,6"
679682

680683
# na_filter=True --> missing value becomes NaN.
@@ -798,7 +801,15 @@ def test_bool_and_nan_to_int(all_parsers):
798801
True
799802
False
800803
"""
801-
with pytest.raises(ValueError, match="convert|NoneType"):
804+
msg = (
805+
"cannot safely convert passed user dtype of int64 for "
806+
"<class 'numpy.bool'> dtyped data in column 0 due to NA values"
807+
)
808+
if parser.engine == "python":
809+
msg = "Unable to convert column 0 to type int64"
810+
elif parser.engine == "pyarrow":
811+
msg = r"cannot convert NA to integer"
812+
with pytest.raises(ValueError, match=msg):
802813
parser.read_csv(StringIO(data), dtype="int")
803814

804815

0 commit comments

Comments (0)