Merged
14 changes: 7 additions & 7 deletions python/datafusion/context.py
@@ -535,7 +535,7 @@ def register_listing_table(
self,
name: str,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
Member

To avoid a breaking change, we could make the type hint

table_partition_cols: list[tuple[str, pa.DataType]] | list[tuple[str, str]] | None = None

Then in the Python code just below we can type-check table_partition_cols and coerce any string entries to the corresponding PyArrow data type. Alternatively, we could overload the method signature and mark the old one as deprecated for a couple of releases. Either way we avoid a breaking change while giving users the opportunity to upgrade on their own schedule.

Contributor Author

That's fair, I opted for overloading and raising a deprecation warning.
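
For context, here is a minimal sketch of what that overload-plus-coercion path could look like on the Python side. The helper name `_coerce_partition_cols` and the warning text are hypothetical, not the code merged in this PR; the legacy string mapping mirrors the `"string"` and `"int"` cases handled by the removed `convert_table_partition_cols` in src/context.rs.

```python
# Hypothetical sketch only: accepts both the new (name, pa.DataType) form and
# the legacy (name, "string") form, warning on the latter.
from __future__ import annotations

import warnings

import pyarrow as pa

# Legacy string spellings supported by the removed Rust converter.
_LEGACY_PARTITION_TYPES: dict[str, pa.DataType] = {
    "string": pa.string(),  # was DataType::Utf8
    "int": pa.int32(),      # was DataType::Int32
}


def _coerce_partition_cols(
    cols: list[tuple[str, pa.DataType]] | list[tuple[str, str]] | None,
) -> list[tuple[str, pa.DataType]] | None:
    """Coerce legacy string-typed partition columns to PyArrow data types."""
    if cols is None:
        return None
    coerced: list[tuple[str, pa.DataType]] = []
    for name, dtype in cols:
        if isinstance(dtype, str):
            warnings.warn(
                "Passing partition column types as strings is deprecated; "
                "pass a pyarrow.DataType instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if dtype not in _LEGACY_PARTITION_TYPES:
                msg = f"Unsupported partition column type: {dtype!r}"
                raise ValueError(msg)
            dtype = _LEGACY_PARTITION_TYPES[dtype]
        coerced.append((name, dtype))
    return coerced


# Both spellings resolve to the same result; the legacy one emits a warning.
assert _coerce_partition_cols([("grp", "string")]) == [("grp", pa.string())]
assert _coerce_partition_cols([("grp", pa.string())]) == [("grp", pa.string())]
```

Each registration/read method could then pass its table_partition_cols through a helper like this before handing the list to the Rust bindings.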

file_extension: str = ".parquet",
schema: pa.Schema | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -774,7 +774,7 @@ def register_parquet(
self,
name: str,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -865,7 +865,7 @@ def register_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> None:
"""Register a JSON file as a table.
@@ -902,7 +902,7 @@ def register_avro(
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_extension: str = ".avro",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
) -> None:
"""Register an Avro file as a table.

@@ -977,7 +977,7 @@ def read_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -1016,7 +1016,7 @@ def read_csv(
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -1060,7 +1060,7 @@ def read_csv(
def read_parquet(
self,
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
6 changes: 3 additions & 3 deletions python/datafusion/io.py
@@ -34,7 +34,7 @@

def read_parquet(
path: str | pathlib.Path,
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str]] | None = None,
table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
18 changes: 9 additions & 9 deletions python/tests/test_sql.py
@@ -180,7 +180,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
ctx.register_parquet(
"datapp",
dir_root,
table_partition_cols=[("grp", "string")],
table_partition_cols=[("grp", pa.string())],
parquet_pruning=True,
file_extension=".parquet",
)
@@ -488,9 +488,9 @@ def test_register_listing_table(
):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
(dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True)

table = pa.Table.from_arrays(
[
@@ -501,21 +501,21 @@
names=["int", "str", "float"],
)
pa.parquet.write_table(
table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet"
)
pa.parquet.write_table(
table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet"
)
pa.parquet.write_table(
table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet"
)

dir_root = f"file://{dir_root}/" if path_to_str else dir_root

ctx.register_listing_table(
"my_table",
dir_root,
table_partition_cols=[("grp", "string"), ("date_id", "int")],
table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
file_extension=".parquet",
schema=table.schema if pass_schema else None,
file_sort_order=file_sort_order,
@@ -531,7 +531,7 @@
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}

result = ctx.sql(
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501
).collect()
result = pa.Table.from_batches(result)

89 changes: 56 additions & 33 deletions src/context.rs
@@ -353,15 +353,20 @@ impl PySessionContext {
&mut self,
name: &str,
path: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_extension: &str,
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
) -> PyDataFusionResult<()> {
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
.with_file_extension(file_extension)
.with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.with_table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
)
.with_file_sort_order(
file_sort_order
.unwrap_or_default()
@@ -629,7 +634,7 @@ impl PySessionContext {
&mut self,
name: &str,
path: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
@@ -638,7 +643,12 @@
py: Python,
) -> PyDataFusionResult<()> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
)
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
@@ -718,7 +728,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<()> {
@@ -728,7 +738,12 @@

let mut options = NdJsonReadOptions::default()
.file_compression_type(parse_file_compression_type(file_compression_type)?)
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
.table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);
Expand All @@ -751,15 +766,19 @@ impl PySessionContext {
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

let mut options = AvroReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
let mut options = AvroReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

@@ -850,15 +869,20 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
let mut options = NdJsonReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
)
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
@@ -891,7 +915,7 @@ impl PySessionContext {
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
@@ -907,7 +931,12 @@
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension)
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
)
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema = schema.as_ref().map(|x| &x.0);

@@ -937,7 +966,7 @@ impl PySessionContext {
pub fn read_parquet(
&self,
path: &str,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
@@ -946,7 +975,12 @@
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
)
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
Expand All @@ -968,12 +1002,16 @@ impl PySessionContext {
&self,
path: &str,
schema: Option<PyArrowType<Schema>>,
table_partition_cols: Vec<(String, String)>,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_extension: &str,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let mut options = AvroReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
let mut options = AvroReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
@@ -1072,21 +1110,6 @@ impl PySessionContext {
}
}

pub fn convert_table_partition_cols(
table_partition_cols: Vec<(String, String)>,
) -> PyDataFusionResult<Vec<(String, DataType)>> {
table_partition_cols
.into_iter()
.map(|(name, ty)| match ty.as_str() {
"string" => Ok((name, DataType::Utf8)),
"int" => Ok((name, DataType::Int32)),
_ => Err(crate::errors::PyDataFusionError::Common(format!(
"Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
))),
})
.collect::<Result<Vec<_>, _>>()
}

pub fn parse_file_compression_type(
file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {