Skip to content

Commit c7d2bfb

Browse files
authored
Fix use_arrow_dtype parameter for read_parquet (#2698)
1 parent dab1fa5 commit c7d2bfb

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

mars/dataframe/datasource/read_parquet.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@
4848
from ...utils import is_object_dtype
4949
from ..arrays import ArrowStringDtype
5050
from ..operands import OutputType
51-
from ..utils import parse_index, to_arrow_dtypes, contain_arrow_dtype
51+
from ..utils import (
52+
parse_index,
53+
to_arrow_dtypes,
54+
contain_arrow_dtype,
55+
arrow_table_to_pandas_dataframe,
56+
)
5257
from .core import (
5358
IncrementalIndexDatasource,
5459
ColumnPruneSupportedDataSourceMixin,
@@ -351,7 +356,7 @@ def _execute_partitioned(cls, ctx, op: "DataFrameReadParquet"):
351356
table = piece.read(partitions=partitions)
352357
if op.nrows is not None:
353358
table = table.slice(0, op.nrows)
354-
ctx[out.key] = table.to_pandas()
359+
ctx[out.key] = arrow_table_to_pandas_dataframe(table, op.use_arrow_dtype)
355360

356361
@classmethod
357362
def execute(cls, ctx, op: "DataFrameReadParquet"):
@@ -500,10 +505,10 @@ def read_parquet(
500505
if columns:
501506
dtypes = dtypes[columns]
502507

503-
if use_arrow_dtype is None:
504-
use_arrow_dtype = options.dataframe.use_arrow_dtype
505-
if use_arrow_dtype:
506-
dtypes = to_arrow_dtypes(dtypes)
508+
if use_arrow_dtype is None:
509+
use_arrow_dtype = options.dataframe.use_arrow_dtype
510+
if use_arrow_dtype:
511+
dtypes = to_arrow_dtypes(dtypes)
507512

508513
index_value = parse_index(pd.RangeIndex(-1))
509514
columns_value = parse_index(dtypes.index, store_data=True)

mars/dataframe/datasource/tests/test_datasource_execution.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,12 @@ def test_read_parquet_arrow(setup):
10501050
r = mdf.execute().fetch()
10511051
pd.testing.assert_frame_equal(df, r.sort_values("a").reset_index(drop=True))
10521052

1053+
# test `use_arrow_dtype=True`
1054+
mdf = md.read_parquet(f"{tempdir}/*.parquet", use_arrow_dtype=True)
1055+
result = mdf.execute().fetch()
1056+
assert isinstance(mdf.dtypes.iloc[1], md.ArrowStringDtype)
1057+
assert isinstance(result.dtypes.iloc[1], md.ArrowStringDtype)
1058+
10531059
mdf = md.read_parquet(
10541060
f"{tempdir}/*.parquet",
10551061
groups_as_chunks=True,
@@ -1058,6 +1064,23 @@ def test_read_parquet_arrow(setup):
10581064
r = mdf.execute().fetch()
10591065
pd.testing.assert_frame_equal(df, r.sort_values("a").reset_index(drop=True))
10601066

1067+
# test partitioned
1068+
with tempfile.TemporaryDirectory() as tempdir:
1069+
df = pd.DataFrame(
1070+
{
1071+
"a": np.random.rand(300),
1072+
"b": [f"s{i}" for i in range(300)],
1073+
"c": np.random.choice(["a", "b", "c"], (300,)),
1074+
}
1075+
)
1076+
df.to_parquet(tempdir, partition_cols=["c"])
1077+
mdf = md.read_parquet(tempdir)
1078+
r = mdf.execute().fetch().astype(df.dtypes)
1079+
pd.testing.assert_frame_equal(
1080+
df.sort_values("a").reset_index(drop=True),
1081+
r.sort_values("a").reset_index(drop=True),
1082+
)
1083+
10611084

10621085
@pytest.mark.skipif(fastparquet is None, reason="fastparquet not installed")
10631086
def test_read_parquet_fast_parquet(setup):

0 commit comments

Comments (0)