Skip to content

Commit 835edf6

Browse files
committed
fix: pass dtypes to read_json with pyarrow engine
1 parent fe494c9 commit 835edf6

File tree

3 files changed

+74
-22
lines changed

3 files changed

+74
-22
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ I/O
732732
- Bug in :meth:`read_stata` where the missing code for double was not recognised for format versions 105 and prior (:issue:`58149`)
733733
- Bug in :meth:`set_option` where setting the pandas option ``display.html.use_mathjax`` to ``False`` has no effect (:issue:`59884`)
734734
- Bug in :meth:`to_excel` where :class:`MultiIndex` columns would be merged to a single row when ``merge_cells=False`` is passed (:issue:`60274`)
735+
- Bug in :meth:`read_json` where the given ``dtype`` was ignored when ``engine="pyarrow"`` (:issue:`59516`)
735736

736737
Period
737738
^^^^^^

pandas/io/json/_json.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pandas.core.dtypes.common import (
3333
ensure_str,
3434
is_string_dtype,
35+
pandas_dtype,
3536
)
3637
from pandas.core.dtypes.dtypes import PeriodDtype
3738

@@ -43,6 +44,7 @@
4344
isna,
4445
notna,
4546
to_datetime,
47+
ArrowDtype,
4648
)
4749
from pandas.core.reshape.concat import concat
4850
from pandas.core.shared_docs import _shared_docs
@@ -942,29 +944,56 @@ def read(self) -> DataFrame | Series:
942944
obj: DataFrame | Series
943945
with self:
944946
if self.engine == "pyarrow":
945-
pyarrow_json = import_optional_dependency("pyarrow.json")
946-
pa_table = pyarrow_json.read_json(self.data)
947-
return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
947+
obj = self._read_pyarrow()
948948
elif self.engine == "ujson":
949-
if self.lines:
950-
if self.chunksize:
951-
obj = concat(self)
952-
elif self.nrows:
953-
lines = list(islice(self.data, self.nrows))
954-
lines_json = self._combine_lines(lines)
955-
obj = self._get_object_parser(lines_json)
956-
else:
957-
data = ensure_str(self.data)
958-
data_lines = data.split("\n")
959-
obj = self._get_object_parser(self._combine_lines(data_lines))
960-
else:
961-
obj = self._get_object_parser(self.data)
962-
if self.dtype_backend is not lib.no_default:
963-
return obj.convert_dtypes(
964-
infer_objects=False, dtype_backend=self.dtype_backend
965-
)
966-
else:
967-
return obj
949+
obj = self._read_ujson()
950+
951+
return obj
952+
953+
def _read_pyarrow(self) -> DataFrame:
    """
    Read JSON using the pyarrow engine.

    When ``self.dtype`` is a dict, every entry that resolves to an
    :class:`ArrowDtype` is forwarded to the pyarrow parser through an
    explicit schema. Entries that resolve to any other dtype are not
    forwarded, and a ``dtype`` that is not a dict is not forwarded at all;
    in both cases the column types are left to the parser.

    Returns
    -------
    DataFrame
        The parsed table converted with ``arrow_table_to_pandas`` using
        ``self.dtype_backend``.
    """
    pyarrow_json = import_optional_dependency("pyarrow.json")
    options = None

    if isinstance(self.dtype, dict):
        pa = import_optional_dependency("pyarrow")
        fields = []
        for field, dtype in self.dtype.items():
            # Resolve each user-supplied dtype once; the resolved object is
            # needed both for the ArrowDtype check and for its pyarrow type.
            resolved = pandas_dtype(dtype)
            if isinstance(resolved, ArrowDtype):
                fields.append((field, resolved.pyarrow_dtype))

        schema = pa.schema(fields)
        options = pyarrow_json.ParseOptions(explicit_schema=schema)

    pa_table = pyarrow_json.read_json(self.data, parse_options=options)
    return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
973+
974+
def _read_ujson(self) -> DataFrame | Series:
    """
    Read JSON using the ujson engine.

    The input is dispatched on ``self.lines``, ``self.chunksize`` and
    ``self.nrows``; afterwards the result is passed through
    ``convert_dtypes`` when a ``dtype_backend`` was requested.
    """
    parsed: DataFrame | Series
    if not self.lines:
        # Regular (non line-delimited) JSON: parse the payload as is.
        parsed = self._get_object_parser(self.data)
    elif self.chunksize:
        # Chunked reading: the reader itself is handed to concat.
        parsed = concat(self)
    elif self.nrows:
        # Only take the first ``nrows`` lines from the data.
        head = list(islice(self.data, self.nrows))
        parsed = self._get_object_parser(self._combine_lines(head))
    else:
        all_lines = ensure_str(self.data).split("\n")
        parsed = self._get_object_parser(self._combine_lines(all_lines))

    if self.dtype_backend is lib.no_default:
        return parsed
    return parsed.convert_dtypes(
        infer_objects=False, dtype_backend=self.dtype_backend
    )
968997

969998
def _get_object_parser(self, json: str) -> DataFrame | Series:
970999
"""

pandas/tests/io/json/test_pandas.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,28 @@ def test_read_json_dtype_backend(
21832183
# string_storage setting -> ignore that for checking the result
21842184
tm.assert_frame_equal(result, expected, check_column_type=False)
21852185

2186+
@td.skip_if_no("pyarrow")  # type: ignore
def test_read_json_pyarrow_with_dtype(self, datapath):
    # GH#59516
    # dtypes requested via ``dtype`` must be honoured by the pyarrow engine
    requested = {"a": "int32[pyarrow]", "b": "int64[pyarrow]"}
    path = datapath("io", "json", "data", "line_delimited.json")

    frame = read_json(
        path,
        dtype=requested,
        lines=True,
        engine="pyarrow",
        dtype_backend="pyarrow",
    )

    expected = Series(
        [pd.ArrowDtype.construct_from_string(name) for name in requested.values()],
        index=list(requested),
    )
    tm.assert_series_equal(frame.dtypes, expected)
2207+
21862208
@pytest.mark.parametrize("orient", ["split", "records", "index"])
21872209
def test_read_json_nullable_series(self, string_storage, dtype_backend, orient):
21882210
# GH#50750

0 commit comments

Comments
 (0)