diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 26c3d2e22..d51ba2287 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,8 +19,11 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Protocol +import pyarrow as pa + try: from warnings import deprecated # Python 3.13+ except ImportError: @@ -42,7 +45,6 @@ import pandas as pd import polars as pl - import pyarrow as pa from datafusion.plan import ExecutionPlan, LogicalPlan @@ -535,7 +537,7 @@ def register_listing_table( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, file_sort_order: list[list[Expr | SortExpr]] | None = None, @@ -556,6 +558,7 @@ def register_listing_table( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order_raw = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -774,7 +777,7 @@ def register_parquet( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -802,6 +805,7 @@ def register_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_parquet( name, str(path), @@ -865,7 +869,7 @@ def register_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> None: """Register a JSON file as a table. @@ -886,6 +890,7 @@ def register_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_json( name, str(path), @@ -902,7 +907,7 @@ def register_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, file_extension: str = ".avro", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, ) -> None: """Register an Avro file as a table. @@ -918,6 +923,7 @@ def register_avro( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -977,7 +983,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. 
@@ -997,6 +1003,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) return DataFrame( self.ctx.read_json( str(path), @@ -1016,7 +1023,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -1041,6 +1048,7 @@ def read_csv( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) path = [str(p) for p in path] if isinstance(path, list) else str(path) @@ -1060,7 +1068,7 @@ def read_csv( def read_parquet( self, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -1089,6 +1097,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -1110,7 +1119,7 @@ def read_avro( self, path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. @@ -1126,6 +1135,7 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] + file_partition_cols = self._convert_table_partition_cols(file_partition_cols) return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) @@ -1142,3 +1152,41 @@ def read_table(self, table: Table) -> DataFrame: def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions)) + + @staticmethod + def _convert_table_partition_cols( + table_partition_cols: list[tuple[str, str | pa.DataType]], + ) -> list[tuple[str, pa.DataType]]: + warn = False + converted_table_partition_cols = [] + + for col, data_type in table_partition_cols: + if isinstance(data_type, str): + warn = True + if data_type == "string": + converted_data_type = pa.string() + elif data_type == "int": + converted_data_type = pa.int32() + else: + message = ( + f"Unsupported literal data type '{data_type}' for partition " + "column. 
Supported types are 'string' and 'int'" + ) + raise ValueError(message) + else: + converted_data_type = data_type + + converted_table_partition_cols.append((col, converted_data_type)) + + if warn: + message = ( + "using literals for table_partition_cols data types is deprecated, " + "use pyarrow types instead" + ) + warnings.warn( + message, + category=DeprecationWarning, + stacklevel=2, + ) + + return converted_table_partition_cols diff --git a/python/datafusion/io.py index ef5ebf96f..551e20a6f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -34,7 +34,7 @@ def read_parquet( path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -83,7 +83,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. @@ -124,7 +124,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -171,7 +171,7 @@ def read_csv( def read_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. 
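A quick usage sketch of the new Python API surface (the table names and the `data/events/` path below are hypothetical, chosen only to illustrate the change): partition column types are now passed as `pa.DataType` values, while the old `"string"` / `"int"` literals keep working but emit a `DeprecationWarning`.

```python
import pyarrow as pa

from datafusion import SessionContext

ctx = SessionContext()

# New style: partition column types are given as pyarrow DataTypes.
ctx.register_parquet(
    "events",
    "data/events/",  # hypothetical Hive-partitioned directory, e.g. grp=a/date=2020-10-05/
    table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
)

# Legacy style: still accepted for 'string' and 'int', but now emits a
# DeprecationWarning via _convert_table_partition_cols.
ctx.register_parquet(
    "events_legacy",
    "data/events/",
    table_partition_cols=[("grp", "string")],
)
```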
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index b6348e3a0..41cee4ef3 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} -@pytest.mark.parametrize("path_to_str", [True, False]) -def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): +@pytest.mark.parametrize( + ("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)] +) +def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a").mkdir(exist_ok=False) @@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = str(dir_root) if path_to_str else dir_root + partition_data_type = "string" if legacy_data_type else pa.string() + ctx.register_parquet( "datapp", dir_root, - table_partition_cols=[("grp", "string")], + table_partition_cols=[("grp", partition_data_type)], parquet_pruning=True, file_extension=".parquet", ) @@ -488,9 +492,9 @@ def test_register_listing_table( ): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) - (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True) table = pa.Table.from_arrays( [ @@ -501,13 +505,13 @@ def test_register_listing_table( names=["int", "str", "float"], ) pa.parquet.write_table( - table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet" + table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet" + table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet" + table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet" ) dir_root = f"file://{dir_root}/" if path_to_str else dir_root @@ -515,7 +519,7 @@ def test_register_listing_table( ctx.register_listing_table( "my_table", dir_root, - table_partition_cols=[("grp", "string"), ("date_id", "int")], + table_partition_cols=[("grp", pa.string()), ("date", pa.date64())], file_extension=".parquet", schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, ) @@ -531,7 +535,7 @@ def test_register_listing_table( assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2} result = ctx.sql( - "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501 + "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501 ).collect() result = pa.Table.from_batches(result) diff --git a/src/context.rs b/src/context.rs index cc3d8e8e9..05b532d50 100644 --- a/src/context.rs +++ b/src/context.rs @@ -353,7 +353,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, file_extension: &str, schema: Option<PyArrowType<Schema>>, file_sort_order: Option<Vec<Vec<PySortExpr>>>, @@ -361,7 +361,12 @@ impl PySessionContext { ) -> 
PyDataFusionResult<()> { let options = ListingOptions::new(Arc::new(ParquetFormat::new())) .with_file_extension(file_extension) - .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .with_table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ) .with_file_sort_order( file_sort_order .unwrap_or_default() @@ -629,7 +634,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -638,7 +643,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult<()> { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -718,7 +728,7 @@ impl PySessionContext { schema: Option<PyArrowType<Schema>>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, file_compression_type: Option<String>, py: Python, ) -> PyDataFusionResult<()> { @@ -728,7 +738,12 @@ impl PySessionContext { let mut options = NdJsonReadOptions::default() .file_compression_type(parse_file_compression_type(file_compression_type)?) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -751,15 +766,19 @@ impl PySessionContext { path: PathBuf, schema: Option<PyArrowType<Schema>>, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, py: Python, ) -> PyDataFusionResult<()> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ); options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -850,7 +869,7 @@ impl PySessionContext { schema: Option<PyArrowType<Schema>>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, file_compression_type: Option<String>, py: Python, ) -> PyDataFusionResult<PyDataFrame> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; let mut options = NdJsonReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) 
+ .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; @@ -891,7 +915,7 @@ impl PySessionContext { delimiter: &str, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, file_compression_type: Option<String>, py: Python, ) -> PyDataFusionResult<PyDataFrame> { @@ -907,7 +931,12 @@ impl PySessionContext { .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) .file_extension(file_extension) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema = schema.as_ref().map(|x| &x.0); @@ -937,7 +966,7 @@ impl PySessionContext { pub fn read_parquet( &self, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -946,7 +975,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult<PyDataFrame> { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -968,12 +1002,16 @@ impl PySessionContext { &self, path: &str, schema: Option<PyArrowType<Schema>>, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, file_extension: &str, py: Python, ) -> PyDataFusionResult<PyDataFrame> { - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ); options.file_extension = file_extension; let df = if let Some(schema) = schema { options.schema = Some(&schema.0); @@ -1072,21 +1110,6 @@ impl PySessionContext { } } -pub fn convert_table_partition_cols( - table_partition_cols: Vec<(String, String)>, -) -> PyDataFusionResult<Vec<(String, DataType)>> { - table_partition_cols - .into_iter() - .map(|(name, ty)| match ty.as_str() { - "string" => Ok((name, DataType::Utf8)), - "int" => Ok((name, DataType::Int32)), - _ => Err(crate::errors::PyDataFusionError::Common(format!( - "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'" - ))), - }) - .collect::<Result<Vec<_>, _>>() -} - pub fn parse_file_compression_type( file_compression_type: Option<String>, ) -> Result<FileCompressionType, PyErr> {
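One gap worth noting: the parametrized test exercises the legacy literal path but never asserts that the new `DeprecationWarning` is actually emitted. A sketch of such a check, reusing the existing `ctx` and `tmp_path` fixtures (the table name and directory layout here are made up for illustration):

```python
import pyarrow as pa
import pyarrow.parquet as pq
import pytest


def test_register_parquet_legacy_partition_type_warns(ctx, tmp_path):
    # Minimal Hive-style layout with a single partition value.
    (tmp_path / "grp=a").mkdir()
    pq.write_table(pa.table({"x": [1, 2, 3]}), tmp_path / "grp=a" / "file.parquet")

    # The legacy "string" literal should still register the table,
    # but is expected to emit a DeprecationWarning during conversion.
    with pytest.warns(DeprecationWarning):
        ctx.register_parquet(
            "datapp_legacy",
            tmp_path,
            table_partition_cols=[("grp", "string")],
        )
```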