Skip to content

Commit 221328d

Browse files
committed
Use Matt's idea
1 parent 03c6d00 commit 221328d

File tree

3 files changed

+96
-67
lines changed

3 files changed

+96
-67
lines changed

pandas/io/_util.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,23 @@
1616
)
1717
from pandas.compat._optional import import_optional_dependency
1818

19+
from pandas.core.dtypes.common import pandas_dtype
20+
1921
import pandas as pd
2022

2123
if TYPE_CHECKING:
22-
from collections.abc import Callable
24+
from collections.abc import (
25+
Callable,
26+
Hashable,
27+
Sequence,
28+
)
2329

2430
import pyarrow
2531

26-
from pandas._typing import DtypeBackend
32+
from pandas._typing import (
33+
DtypeArg,
34+
DtypeBackend,
35+
)
2736

2837

2938
def _arrow_dtype_mapping() -> dict:
@@ -64,6 +73,8 @@ def arrow_table_to_pandas(
6473
dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
6574
null_to_int64: bool = False,
6675
to_pandas_kwargs: dict | None = None,
76+
dtype: DtypeArg | None = None,
77+
names: Sequence[Hashable] | None = None,
6778
) -> pd.DataFrame:
6879
pa = import_optional_dependency("pyarrow")
6980

@@ -82,12 +93,77 @@ def arrow_table_to_pandas(
8293
elif using_string_dtype():
8394
if pa_version_under19p0:
8495
types_mapper = _arrow_string_types_mapper()
96+
elif dtype is not None:
97+
# GH#56136 Avoid lossy conversion to float64
98+
# We'll convert to numpy below if
99+
types_mapper = {
100+
pa.int8(): pd.Int8Dtype(),
101+
pa.int16(): pd.Int16Dtype(),
102+
pa.int32(): pd.Int32Dtype(),
103+
pa.int64(): pd.Int64Dtype(),
104+
}.get
85105
else:
86106
types_mapper = None
87107
elif dtype_backend is lib.no_default or dtype_backend == "numpy":
88-
types_mapper = None
108+
if dtype is not None:
109+
# GH#56136 Avoid lossy conversion to float64
110+
# We'll convert to numpy below if
111+
types_mapper = {
112+
pa.int8(): pd.Int8Dtype(),
113+
pa.int16(): pd.Int16Dtype(),
114+
pa.int32(): pd.Int32Dtype(),
115+
pa.int64(): pd.Int64Dtype(),
116+
}.get
117+
else:
118+
types_mapper = None
89119
else:
90120
raise NotImplementedError
91121

92122
df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs)
123+
return _post_convert_dtypes(df, dtype_backend, dtype, names)
124+
125+
126+
def _post_convert_dtypes(
127+
df: pd.DataFrame,
128+
dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault,
129+
dtype: DtypeArg | None,
130+
names: Sequence[Hashable] | None,
131+
) -> pd.DataFrame:
132+
if dtype is not None and (
133+
dtype_backend is lib.no_default or dtype_backend == "numpy"
134+
):
135+
# GH#56136 apply any user-provided dtype, and convert any IntegerDtype
136+
# columns the user didn't explicitly ask for.
137+
if isinstance(dtype, dict):
138+
if names is not None:
139+
df.columns = names
140+
141+
cmp_dtypes = {
142+
pd.Int8Dtype(),
143+
pd.Int16Dtype(),
144+
pd.Int32Dtype(),
145+
pd.Int64Dtype(),
146+
}
147+
for col in df.columns:
148+
if col not in dtype and df[col].dtype in cmp_dtypes:
149+
# Any key that the user didn't explicitly specify
150+
# that got converted to IntegerDtype now gets converted
151+
# to numpy dtype.
152+
dtype[col] = df[col].dtype.numpy_dtype
153+
154+
# Ignore non-existent columns from dtype mapping
155+
# like other parsers do
156+
dtype = {
157+
key: pandas_dtype(dtype[key]) for key in dtype if key in df.columns
158+
}
159+
160+
else:
161+
dtype = pandas_dtype(dtype)
162+
163+
try:
164+
df = df.astype(dtype)
165+
except TypeError as err:
166+
# GH#44901 reraise to keep api consistent
167+
raise ValueError(str(err)) from err
168+
93169
return df

pandas/io/parsers/arrow_parser_wrapper.py

Lines changed: 13 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33
from typing import TYPE_CHECKING
44
import warnings
55

6-
import numpy as np
7-
8-
from pandas._config import using_string_dtype
9-
106
from pandas._libs import lib
117
from pandas.compat._optional import import_optional_dependency
128
from pandas.errors import (
@@ -16,20 +12,16 @@
1612
from pandas.util._exceptions import find_stack_level
1713

1814
from pandas.core.dtypes.common import (
19-
is_string_dtype,
2015
pandas_dtype,
2116
)
22-
from pandas.core.dtypes.dtypes import (
23-
BaseMaskedDtype,
24-
)
2517
from pandas.core.dtypes.inference import is_integer
2618

27-
from pandas.core.arrays.string_ import StringDtype
28-
2919
from pandas.io._util import arrow_table_to_pandas
3020
from pandas.io.parsers.base_parser import ParserBase
3121

3222
if TYPE_CHECKING:
23+
import pyarrow as pa
24+
3325
from pandas._typing import ReadBuffer
3426

3527
from pandas import DataFrame
@@ -174,8 +166,8 @@ def _get_convert_options(self):
174166

175167
return convert_options
176168

177-
def _adjust_column_names(self, frame: DataFrame) -> tuple[DataFrame, bool]:
178-
num_cols = len(frame.columns)
169+
def _adjust_column_names(self, table: pa.Table) -> bool:
170+
num_cols = len(table.columns)
179171
multi_index_named = True
180172
if self.header is None:
181173
if self.names is None:
@@ -188,8 +180,7 @@ def _adjust_column_names(self, frame: DataFrame) -> tuple[DataFrame, bool]:
188180
columns_prefix = [str(x) for x in range(num_cols - len(self.names))]
189181
self.names = columns_prefix + self.names
190182
multi_index_named = False
191-
frame.columns = self.names
192-
return frame, multi_index_named
183+
return multi_index_named
193184

194185
def _finalize_index(self, frame: DataFrame, multi_index_named: bool) -> DataFrame:
195186
if self.index_col is not None:
@@ -312,13 +303,7 @@ def read(self) -> DataFrame:
312303

313304
table = table.cast(new_schema)
314305

315-
workaround = False
316-
pass_backend = dtype_backend
317-
if self.dtype is not None and dtype_backend != "pyarrow":
318-
# We pass dtype_backend="pyarrow" and subsequently cast
319-
# to avoid lossy conversion e.g. GH#56136
320-
workaround = True
321-
pass_backend = "numpy_nullable"
306+
multi_index_named = self._adjust_column_names(table)
322307

323308
with warnings.catch_warnings():
324309
warnings.filterwarnings(
@@ -327,49 +312,14 @@ def read(self) -> DataFrame:
327312
DeprecationWarning,
328313
)
329314
frame = arrow_table_to_pandas(
330-
table, dtype_backend=pass_backend, null_to_int64=True
315+
table,
316+
dtype_backend=dtype_backend,
317+
null_to_int64=True,
318+
dtype=self.dtype,
319+
names=self.names,
331320
)
332321

333-
frame, multi_index_named = self._adjust_column_names(frame)
334-
335-
if workaround and dtype_backend != "numpy_nullable":
336-
old_dtype = self.dtype
337-
if not isinstance(old_dtype, dict):
338-
# e.g. test_categorical_dtype_utf16
339-
old_dtype = dict.fromkeys(frame.columns, old_dtype)
340-
341-
# _finalize_pandas_output will call astype, but we need to make
342-
# sure all keys are populated appropriately.
343-
new_dtype = {}
344-
for key in frame.columns:
345-
ser = frame[key]
346-
if isinstance(ser.dtype, BaseMaskedDtype):
347-
new_dtype[key] = ser.dtype.numpy_dtype
348-
if (
349-
key in old_dtype
350-
and not using_string_dtype()
351-
and is_string_dtype(old_dtype[key])
352-
and not isinstance(old_dtype[key], StringDtype)
353-
and ser.array._hasna
354-
):
355-
# Cast to make sure we get "NaN" string instead of "NA"
356-
frame[key] = ser.astype(old_dtype[key])
357-
frame.loc[ser.isna(), key] = np.nan
358-
old_dtype[key] = object # Avoid re-casting
359-
elif isinstance(ser.dtype, StringDtype):
360-
# We cast here in case the user passed "category" in
361-
# order to get the correct dtype.categories.dtype
362-
# e.g. test_categorical_dtype_utf16
363-
if not using_string_dtype():
364-
sdt = np.dtype(object)
365-
frame[key] = ser.astype(sdt)
366-
frame.loc[ser.isna(), key] = np.nan
367-
else:
368-
sdt = StringDtype(na_value=np.nan) # type: ignore[assignment]
369-
frame[key] = frame[key].astype(sdt)
370-
new_dtype[key] = sdt
371-
372-
new_dtype.update(old_dtype)
373-
self.dtype = new_dtype
322+
if self.header is None:
323+
frame.columns = self.names
374324

375325
return self._finalize_pandas_output(frame, multi_index_named)

pandas/tests/io/parser/test_na_values.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,10 @@ def test_bool_and_nan_to_int(all_parsers):
808808
if parser.engine == "python":
809809
msg = "Unable to convert column 0 to type int(64|32)"
810810
elif parser.engine == "pyarrow":
811-
msg = r"cannot convert NA to integer"
811+
msg = (
812+
r"int\(\) argument must be a string, a bytes-like object or a "
813+
"real number, not 'NoneType"
814+
)
812815
with pytest.raises(ValueError, match=msg):
813816
parser.read_csv(StringIO(data), dtype="int")
814817

0 commit comments

Comments (0)