diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index e9d2dba75..d784ab926 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -45,7 +45,13 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions
+from .dataframe import (
+    DataFrame,
+    DataFrameWriteOptions,
+    InsertOp,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .dataframe_formatter import configure_formatter
 from .expr import (
     Expr,
@@ -75,9 +81,11 @@
     "Config",
     "DFSchema",
     "DataFrame",
+    "DataFrameWriteOptions",
     "Database",
     "ExecutionPlan",
     "Expr",
+    "InsertOp",
     "LogicalPlan",
     "ParquetColumnOptions",
     "ParquetWriterOptions",
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index c1b649e33..7a662387f 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -39,10 +39,13 @@
 from typing_extensions import deprecated  # Python 3.12
 
 from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion._internal import DataFrameWriteOptions as DataFrameWriteOptionsInternal
+from datafusion._internal import InsertOp as InsertOpInternal
 from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
 from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
 from datafusion.expr import (
     Expr,
+    SortExpr,
     SortKey,
     ensure_expr,
     ensure_expr_list,
@@ -925,14 +928,23 @@ def except_all(self, other: DataFrame) -> DataFrame:
         """
         return DataFrame(self.df.except_all(other.df))
 
-    def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None:
+    def write_csv(
+        self,
+        path: str | pathlib.Path,
+        with_header: bool = False,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a CSV file.
 
         Args:
             path: Path of the CSV file to write.
             with_header: If true, output the CSV header row.
+            write_options: Options that impact how the DataFrame is written.
         """
-        self.df.write_csv(str(path), with_header)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_csv(str(path), with_header, raw_write_options)
 
     @overload
     def write_parquet(
@@ -940,6 +952,7 @@ def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: str,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...
 
     @overload
@@ -948,6 +961,7 @@ def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: Compression = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...
 
     @overload
@@ -956,6 +970,7 @@ def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: ParquetWriterOptions,
         compression_level: None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None: ...
 
     def write_parquet(
@@ -963,24 +978,30 @@ def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
 
+        Available compression types are:
+
+        - "uncompressed": No compression.
+        - "snappy": Snappy compression.
+        - "gzip": Gzip compression.
+        - "brotli": Brotli compression.
+        - "lz4": LZ4 compression.
+        - "lz4_raw": LZ4_RAW compression.
+        - "zstd": Zstandard compression.
+
+        LZO compression is not yet implemented in arrow-rs and is therefore
+        excluded.
+
         Args:
             path: Path of the Parquet file to write.
             compression: Compression type to use. Default is "ZSTD".
-                Available compression types are:
-                - "uncompressed": No compression.
-                - "snappy": Snappy compression.
-                - "gzip": Gzip compression.
-                - "brotli": Brotli compression.
-                - "lz4": LZ4 compression.
-                - "lz4_raw": LZ4_RAW compression.
-                - "zstd": Zstandard compression.
-                Note: LZO is not yet implemented in arrow-rs and is therefore excluded.
             compression_level: Compression level to use. For ZSTD, the recommended
                 range is 1 to 22, with the default being 4. Higher levels provide
                 better compression but slower speed.
+            write_options: Options that impact how the DataFrame is written.
         """
         if isinstance(compression, ParquetWriterOptions):
             if compression_level is not None:
@@ -998,10 +1019,21 @@ def write_parquet(
         ):
             compression_level = compression.get_default_level()
 
-        self.df.write_parquet(str(path), compression.value, compression_level)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_parquet(
+            str(path),
+            compression.value,
+            compression_level,
+            raw_write_options,
+        )
 
     def write_parquet_with_options(
-        self, path: str | pathlib.Path, options: ParquetWriterOptions
+        self,
+        path: str | pathlib.Path,
+        options: ParquetWriterOptions,
+        write_options: DataFrameWriteOptions | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
 
@@ -1010,6 +1042,7 @@ def write_parquet_with_options(
         Args:
             path: Path of the Parquet file to write.
             options: Sets the writer parquet options (see `ParquetWriterOptions`).
+            write_options: Options that impact how the DataFrame is written.
         """
         options_internal = ParquetWriterOptionsInternal(
             options.data_pagesize_limit,
@@ -1046,19 +1079,45 @@ def write_parquet_with_options(
                 bloom_filter_ndv=opts.bloom_filter_ndv,
             )
 
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
         self.df.write_parquet_with_options(
             str(path),
             options_internal,
             column_specific_options_internal,
+            raw_write_options,
         )
 
-    def write_json(self, path: str | pathlib.Path) -> None:
+    def write_json(
+        self,
+        path: str | pathlib.Path,
+        write_options: DataFrameWriteOptions | None = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
 
         Args:
             path: Path of the JSON file to write.
+            write_options: Options that impact how the DataFrame is written.
+        """
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_json(str(path), write_options=raw_write_options)
+
+    def write_table(
+        self, table_name: str, write_options: DataFrameWriteOptions | None = None
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to a table.
+
+        The table must be registered with the session to perform this operation.
+        Not all table providers support writing operations. See the individual
+        implementations for details.
         """
-        self.df.write_json(str(path))
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_table(table_name, raw_write_options)
 
     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1206,3 +1265,49 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame:
             - For columns not in subset, the original column is kept unchanged
         """
         return DataFrame(self.df.fill_null(value, subset))
+
+
+class InsertOp(Enum):
+    """Insert operation mode.
+
+    These modes are used by the table writing feature to define how record
+    batches should be written to a table.
+    """
+
+    APPEND = InsertOpInternal.APPEND
+    """Appends new rows to the existing table without modifying any existing rows."""
+
+    REPLACE = InsertOpInternal.REPLACE
+    """Replace existing rows that collide with the inserted rows.
+
+    Replacement is typically based on a unique key or primary key.
+    """
+
+    OVERWRITE = InsertOpInternal.OVERWRITE
+    """Overwrites all existing rows in the table with the new rows."""
+
+
+class DataFrameWriteOptions:
+    """Writer options for DataFrame.
+
+    There is no guarantee the table provider supports all writer options.
+    See the individual implementation and documentation for details.
+    """
+
+    def __init__(
+        self,
+        insert_operation: InsertOp | None = None,
+        single_file_output: bool = False,
+        partition_by: str | Sequence[str] | None = None,
+        sort_by: Expr | SortExpr | Sequence[Expr] | Sequence[SortExpr] | None = None,
+    ) -> None:
+        """Instantiate writer options for DataFrame."""
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
+
+        sort_by_raw = sort_list_to_raw_sort_list(sort_by)
+        insert_op = insert_operation.value if insert_operation is not None else None
+
+        self._raw_write_options = DataFrameWriteOptionsInternal(
+            insert_op, single_file_output, partition_by, sort_by_raw
+        )
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index eb686dd19..cd85221c5 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -16,6 +16,7 @@
 # under the License.
 import ctypes
 import datetime
+import itertools
 import os
 import re
 import threading
@@ -27,6 +28,7 @@
 import pytest
 from datafusion import (
     DataFrame,
+    InsertOp,
     ParquetColumnOptions,
     ParquetWriterOptions,
     SessionContext,
@@ -40,6 +42,7 @@
 from datafusion import (
     functions as f,
 )
+from datafusion.dataframe import DataFrameWriteOptions
 from datafusion.dataframe_formatter import (
     DataFrameHtmlFormatter,
     configure_formatter,
@@ -58,9 +61,7 @@ def ctx():
 
 
 @pytest.fixture
-def df():
-    ctx = SessionContext()
-
+def df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
@@ -1830,6 +1831,69 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected
 
 
+def generate_test_write_params() -> list[tuple]:
+    # Overwrite and Replace are not implemented for many table writers
+    insert_ops = [InsertOp.APPEND, None]
+    sort_by_cases = [
+        (None, [1, 2, 3], "unsorted"),
+        (column("c"), [2, 1, 3], "single_column_expr"),
+        (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"),
+        ([column("c"), column("b")], [2, 1, 3], "list_col_expr"),
+        (
+            [column("c").sort(ascending=False), column("b").sort(ascending=False)],
+            [3, 1, 2],
+            "list_sort_expr",
+        ),
+    ]
+
+    formats = ["csv", "json", "parquet", "table"]
+
+    return [
+        pytest.param(
+            output_format,
+            insert_op,
+            sort_by,
+            expected_a,
+            id=f"{output_format}_{test_id}",
+        )
+        for output_format, insert_op, (
+            sort_by,
+            expected_a,
+            test_id,
+        ) in itertools.product(formats, insert_ops, sort_by_cases)
+    ]
+
+
+@pytest.mark.parametrize(
+    ("output_format", "insert_op", "sort_by", "expected_a"),
+    generate_test_write_params(),
+)
+def test_write_files_with_options(
+    ctx, df, tmp_path, output_format, insert_op, sort_by, expected_a
+) -> None:
+    write_options = DataFrameWriteOptions(insert_operation=insert_op, sort_by=sort_by)
+
+    if output_format == "csv":
+        df.write_csv(tmp_path, with_header=True, write_options=write_options)
+        ctx.register_csv("test_table", tmp_path)
+    elif output_format == "json":
+        df.write_json(tmp_path, write_options=write_options)
+        ctx.register_json("test_table", tmp_path)
+    elif output_format == "parquet":
+        df.write_parquet(tmp_path, write_options=write_options)
+        ctx.register_parquet("test_table", tmp_path)
+    elif output_format == "table":
+        batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema())
+        ctx.register_record_batches("test_table", [[batch]])
+        ctx.table("test_table").show()
+        df.write_table("test_table", write_options=write_options)
+
+    result = ctx.table("test_table").to_pydict()["a"]
+    ctx.table("test_table").show()
+
+    assert result == expected_a
+
+
 @pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_json(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
@@ -2322,6 +2386,25 @@ def test_write_parquet_options_error(df, tmp_path):
         df.write_parquet(str(tmp_path), options, compression_level=1)
 
 
+def test_write_table(ctx, df):
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3])],
+        names=["a"],
+    )
+
+    ctx.register_record_batches("t", [[batch]])
+
+    df = ctx.table("t").with_column("a", column("a") * literal(-1))
+
+    ctx.table("t").show()
+
+    df.write_table("t")
+    result = ctx.table("t").sort(column("a")).collect()[0][0].to_pylist()
+    expected = [-3, -2, -1, 1, 2, 3]
+
+    assert result == expected
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 5882acf76..62dfe7209 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -34,6 +34,8 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;
 use datafusion::error::DataFusionError;
 use datafusion::execution::SendableRecordBatchStream;
+use datafusion::logical_expr::dml::InsertOp;
+use datafusion::logical_expr::SortExpr;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
 use datafusion_ffi::table_provider::FFI_TableProvider;
@@ -742,18 +744,27 @@ impl PyDataFrame {
     }
 
     /// Write a `DataFrame` to a CSV file.
-    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
+    fn write_csv(
+        &self,
+        py: Python,
+        path: &str,
+        with_header: bool,
+        write_options: Option<PyDataFrameWriteOptions>,
+    ) -> PyDataFusionResult<()> {
         let csv_options = CsvOptions {
             has_header: Some(with_header),
             ..Default::default()
         };
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
+
         wait_for_future(
             py,
-            self.df.as_ref().clone().write_csv(
-                path,
-                DataFrameWriteOptions::new(),
-                Some(csv_options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_csv(path, write_options, Some(csv_options)),
         )??;
         Ok(())
     }
@@ -762,13 +773,15 @@ impl PyDataFrame {
     #[pyo3(signature = (
         path,
         compression="zstd",
-        compression_level=None
+        compression_level=None,
+        write_options=None,
     ))]
     fn write_parquet(
         &self,
         path: &str,
         compression: &str,
         compression_level: Option<u32>,
+        write_options: Option<PyDataFrameWriteOptions>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
@@ -807,14 +820,16 @@ impl PyDataFrame {
 
         let mut options = TableParquetOptions::default();
         options.global.compression = Some(compression_string);
 
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
         wait_for_future(
             py,
-            self.df.as_ref().clone().write_parquet(
-                path,
-                DataFrameWriteOptions::new(),
-                Option::from(options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_parquet(path, write_options, Option::from(options)),
         )??;
         Ok(())
     }
@@ -825,6 +840,7 @@ impl PyDataFrame {
         path: &str,
         options: PyParquetWriterOptions,
         column_specific_options: HashMap<String, PyParquetColumnOptions>,
+        write_options: Option<PyDataFrameWriteOptions>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         let table_options = TableParquetOptions {
@@ -835,12 +851,14 @@ impl PyDataFrame {
                 .collect(),
             ..Default::default()
         };
-
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
         wait_for_future(
             py,
             self.df.as_ref().clone().write_parquet(
                 path,
-                DataFrameWriteOptions::new(),
+                write_options,
                 Option::from(table_options),
             ),
         )??;
@@ -848,13 +866,40 @@ impl PyDataFrame {
     }
 
     /// Executes a query and writes the results to a partitioned JSON file.
-    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
+    fn write_json(
+        &self,
+        path: &str,
+        py: Python,
+        write_options: Option<PyDataFrameWriteOptions>,
+    ) -> PyDataFusionResult<()> {
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
+        wait_for_future(
+            py,
+            self.df
+                .as_ref()
+                .clone()
+                .write_json(path, write_options, None),
+        )??;
+        Ok(())
+    }
+
+    fn write_table(
+        &self,
+        py: Python,
+        table_name: &str,
+        write_options: Option<PyDataFrameWriteOptions>,
+    ) -> PyDataFusionResult<()> {
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
         wait_for_future(
             py,
             self.df
                 .as_ref()
                 .clone()
-                .write_json(path, DataFrameWriteOptions::new(), None),
+                .write_table(table_name, write_options),
         )??;
         Ok(())
     }
@@ -993,6 +1038,67 @@ impl PyDataFrame {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[pyclass(frozen, eq, eq_int, name = "InsertOp", module = "datafusion")]
+pub enum PyInsertOp {
+    APPEND,
+    REPLACE,
+    OVERWRITE,
+}
+
+impl From<PyInsertOp> for InsertOp {
+    fn from(value: PyInsertOp) -> Self {
+        match value {
+            PyInsertOp::APPEND => InsertOp::Append,
+            PyInsertOp::REPLACE => InsertOp::Replace,
+            PyInsertOp::OVERWRITE => InsertOp::Overwrite,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+#[pyclass(frozen, name = "DataFrameWriteOptions", module = "datafusion")]
+pub struct PyDataFrameWriteOptions {
+    insert_operation: InsertOp,
+    single_file_output: bool,
+    partition_by: Vec<String>,
+    sort_by: Vec<SortExpr>,
+}
+
+impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
+    fn from(value: PyDataFrameWriteOptions) -> Self {
+        DataFrameWriteOptions::new()
+            .with_insert_operation(value.insert_operation)
+            .with_single_file_output(value.single_file_output)
+            .with_partition_by(value.partition_by)
+            .with_sort_by(value.sort_by)
+    }
+}
+
+#[pymethods]
+impl PyDataFrameWriteOptions {
+    #[new]
+    fn new(
+        insert_operation: Option<PyInsertOp>,
+        single_file_output: bool,
+        partition_by: Option<Vec<String>>,
+        sort_by: Option<Vec<PySortExpr>>,
+    ) -> Self {
+        let insert_operation = insert_operation.map(Into::into).unwrap_or(InsertOp::Append);
+        let sort_by = sort_by
+            .unwrap_or_default()
+            .into_iter()
+            .map(Into::into)
+            .collect();
+        Self {
+            insert_operation,
+            single_file_output,
+            partition_by: partition_by.unwrap_or_default(),
+            sort_by,
+        }
+    }
+}
+
 /// Print DataFrame
 fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
     // Get string representation of record batches
diff --git a/src/lib.rs b/src/lib.rs
index 29d3f41da..661cbd658 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -86,6 +86,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<dataframe::PyDataFrameWriteOptions>()?;
+    m.add_class::<dataframe::PyInsertOp>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
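
Usage sketch for the options introduced above. This is illustrative only and not part of the diff: the session data, file path, and table name are assumptions, but the classes and keyword arguments match the Python API added in python/datafusion/dataframe.py.

# Sketch: sort the written output and append to an existing table.
from datafusion import DataFrameWriteOptions, InsertOp, SessionContext, column

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

options = DataFrameWriteOptions(
    insert_operation=InsertOp.APPEND,          # append rows; default if omitted
    sort_by=column("a").sort(ascending=False),  # sort output by "a" descending
)

# "my_table" must already be registered with the session context,
# and the table provider must support inserts.
df.write_table("my_table", write_options=options)

# The same options object also applies to the file-based writers.
df.write_csv("out.csv", with_header=True, write_options=options)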