66 changes: 57 additions & 9 deletions python/datafusion/context.py
@@ -19,8 +19,11 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Protocol

import pyarrow as pa

try:
from warnings import deprecated # Python 3.13+
except ImportError:
@@ -42,7 +45,6 @@

import pandas as pd
import polars as pl
import pyarrow as pa

from datafusion.plan import ExecutionPlan, LogicalPlan

@@ -535,7 +537,7 @@ def register_listing_table(
self,
name: str,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".parquet",
schema: pa.Schema | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -556,6 +558,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
file_sort_order_raw = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
@@ -774,7 +777,7 @@ def register_parquet(
self,
name: str,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -802,6 +805,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_parquet(
name,
str(path),
@@ -865,7 +869,7 @@ def register_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> None:
"""Register a JSON file as a table.
@@ -886,6 +890,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_json(
name,
str(path),
@@ -902,7 +907,7 @@ def register_avro(
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_extension: str = ".avro",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
"""Register an Avro file as a table.

@@ -918,6 +923,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_avro(
name, str(path), schema, file_extension, table_partition_cols
)
@@ -977,7 +983,7 @@ def read_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -997,6 +1003,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
return DataFrame(
self.ctx.read_json(
str(path),
@@ -1016,7 +1023,7 @@ def read_csv(
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -1041,6 +1048,7 @@
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)

path = [str(p) for p in path] if isinstance(path, list) else str(path)

@@ -1060,7 +1068,7 @@
def read_parquet(
self,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -1089,6 +1097,7 @@ def read_parquet(
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
file_sort_order = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
@@ -1110,7 +1119,7 @@ def read_avro(
self,
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_partition_cols: list[tuple[str, str]] | None = None,
file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1126,6 +1135,7 @@
"""
if file_partition_cols is None:
file_partition_cols = []
file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
return DataFrame(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)
@@ -1142,3 +1152,41 @@ def read_table(self, table: Table) -> DataFrame:
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
"""Execute the ``plan`` and return the results."""
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))

@staticmethod
def _convert_table_partition_cols(
table_partition_cols: list[tuple[str, str | pa.DataType]],
) -> list[tuple[str, pa.DataType]]:
warn = False
converted_table_partition_cols = []

for col, data_type in table_partition_cols:
if isinstance(data_type, str):
warn = True
if data_type == "string":
converted_data_type = pa.string()
elif data_type == "int":
converted_data_type = pa.int32()
else:
message = (
f"Unsupported literal data type '{data_type}' for partition "
"column. Supported types are 'string' and 'int'"
)
raise ValueError(message)
else:
converted_data_type = data_type

converted_table_partition_cols.append((col, converted_data_type))

if warn:
message = (
"using literals for table_partition_cols data types is deprecated,"
"use pyarrow types instead"
)
warnings.warn(
message,
category=DeprecationWarning,
stacklevel=2,
)

return converted_table_partition_cols
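
For reference, a minimal usage sketch of the typed partition columns introduced above (not part of this diff; the table names, paths, and columns are illustrative only). PyArrow data types are passed directly, while the legacy "string"/"int" literals are still converted but now emit a DeprecationWarning.

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# New style: partition column types given as pyarrow data types.
ctx.register_parquet(
    "events",
    "data/events/",
    table_partition_cols=[("year", pa.int32()), ("region", pa.string())],
)

# Legacy style: still accepted, but _convert_table_partition_cols maps
# "string" -> pa.string() and "int" -> pa.int32() and emits a DeprecationWarning.
ctx.register_parquet(
    "events_legacy",
    "data/events/",
    table_partition_cols=[("year", "int"), ("region", "string")],
)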
8 changes: 4 additions & 4 deletions python/datafusion/io.py
@@ -34,7 +34,7 @@

def read_parquet(
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -171,7 +171,7 @@ def read_csv(
def read_avro(
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_partition_cols: list[tuple[str, str]] | None = None,
file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading Avro data source.
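
The module-level readers in io.py accept the same typed partition columns; below is a hedged sketch (the path and column names are hypothetical, and it assumes the helper can be called standalone as defined in this file).

import pyarrow as pa
from datafusion.io import read_parquet

# Hive-style layout such as data/sales/region=EU/day=2020-10-05/part-0.parquet
df = read_parquet(
    "data/sales/",
    table_partition_cols=[("region", pa.string()), ("day", pa.date64())],
)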
26 changes: 15 additions & 11 deletions python/tests/test_sql.py
@@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path):
assert result.to_pydict() == {"cnt": [100]}


@pytest.mark.parametrize("path_to_str", [True, False])
def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
@pytest.mark.parametrize(
("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)]
)
def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
(dir_root / "grp=a").mkdir(exist_ok=False)
@@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):

dir_root = str(dir_root) if path_to_str else dir_root

partition_data_type = "string" if legacy_data_type else pa.string()

ctx.register_parquet(
"datapp",
dir_root,
table_partition_cols=[("grp", "string")],
table_partition_cols=[("grp", partition_data_type)],
parquet_pruning=True,
file_extension=".parquet",
)
@@ -488,9 +492,9 @@ def test_register_listing_table(
):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
(dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True)

table = pa.Table.from_arrays(
[
@@ -501,21 +505,21 @@ def test_register_listing_table(
names=["int", "str", "float"],
)
pa.parquet.write_table(
table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet"
)
pa.parquet.write_table(
table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet"
)
pa.parquet.write_table(
table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet"
)

dir_root = f"file://{dir_root}/" if path_to_str else dir_root

ctx.register_listing_table(
"my_table",
dir_root,
table_partition_cols=[("grp", "string"), ("date_id", "int")],
table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
file_extension=".parquet",
schema=table.schema if pass_schema else None,
file_sort_order=file_sort_order,
@@ -531,7 +535,7 @@
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}

result = ctx.sql(
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501
).collect()
result = pa.Table.from_batches(result)
