diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 0b38db924..f8aef0c91 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -21,7 +21,16 @@ from __future__ import annotations
 
 import warnings
-from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
+from typing import (
+    Any,
+    Iterable,
+    List,
+    TYPE_CHECKING,
+    Literal,
+    overload,
+    Optional,
+    Union,
+)
 
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
 from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -35,6 +44,60 @@ from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion.expr import Expr, SortExpr, sort_or_default
 
+from enum import Enum
+
+
+# excerpt from deltalake
+# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+class Compression(Enum):
+    """Enum representing the available compression types for Parquet files."""
+
+    UNCOMPRESSED = "uncompressed"
+    SNAPPY = "snappy"
+    GZIP = "gzip"
+    BROTLI = "brotli"
+    LZ4 = "lz4"
+    LZO = "lzo"
+    ZSTD = "zstd"
+    LZ4_RAW = "lz4_raw"
+
+    @classmethod
+    def from_str(cls, value: str) -> "Compression":
+        """Convert a string to a Compression enum value.
+
+        Args:
+            value: The string representation of the compression type.
+
+        Returns:
+            The matching Compression enum value; the lookup is case-insensitive.
+
+        Raises:
+            ValueError: If the string does not match any Compression enum value.
+        """
+        try:
+            return cls(value.lower())
+        except ValueError as err:
+            raise ValueError(
+                f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
+            ) from err
+
+    def get_default_level(self) -> Optional[int]:
+        """Get the default compression level for the compression type.
+
+        Returns:
+            The default compression level for the compression type.
+        """
+        # GZIP, BROTLI default values from deltalake repo
+        # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+        # ZSTD default value from delta-rs
+        # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
+        if self == Compression.GZIP:
+            return 6
+        elif self == Compression.BROTLI:
+            return 1
+        elif self == Compression.ZSTD:
+            return 4
+        return None
 
 
 class DataFrame:
@@ -620,17 +683,36 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: str = "uncompressed",
+        compression: Union[str, Compression] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
 
         Args:
             path: Path of the Parquet file to write.
-            compression: Compression type to use.
-            compression_level: Compression level to use.
-        """
-        self.df.write_parquet(str(path), compression, compression_level)
+            compression: Compression type to use. Defaults to "zstd".
+                Available compression types are:
+                - "uncompressed": No compression.
+                - "snappy": Snappy compression.
+                - "gzip": Gzip compression.
+                - "brotli": Brotli compression.
+                - "lzo": LZO compression.
+                - "lz4": LZ4 compression.
+                - "lz4_raw": LZ4_RAW compression.
+                - "zstd": Zstandard compression.
+            compression_level: Compression level to use. For ZSTD, the
+                recommended range is 1 to 22, with the default being 4.
+                Higher levels give better compression at the cost of speed.
+        """
+        # Convert string to Compression enum if necessary
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)
+
+        if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
+            if compression_level is None:
+                compression_level = compression.get_default_level()
+
+        self.df.write_parquet(str(path), compression.value, compression_level)
 
     def write_json(self, path: str | pathlib.Path) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index b82f95e35..41a96ae6b 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -1107,14 +1107,24 @@ def test_write_compressed_parquet_wrong_compression_level(
     )
 
 
-@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
-def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
+@pytest.mark.parametrize("compression", ["wrong"])
+def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
     path = tmp_path
 
     with pytest.raises(ValueError):
         df.write_parquet(str(path), compression=compression)
 
 
+@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
+def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
+    # Write with zstd, brotli, and gzip without specifying a compression
+    # level; the codec's default level should be applied and the write
+    # should complete without error.
+    path = tmp_path
+
+    df.write_parquet(str(path), compression=compression)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
diff --git a/src/dataframe.rs b/src/dataframe.rs
index e7d6ca6d6..9bdc2a327 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -463,7 +463,7 @@ impl PyDataFrame {
     /// Write a `DataFrame` to a Parquet file.
     #[pyo3(signature = (
         path,
-        compression="uncompressed",
+        compression="zstd",
         compression_level=None
     ))]
     fn write_parquet(