
from __future__ import annotations
import warnings
-from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
+from typing import (
+    Any,
+    Iterable,
+    List,
+    TYPE_CHECKING,
+    Literal,
+    overload,
+    Optional,
+    Union,
+)
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -57,10 +66,7 @@ def from_str(cls, value: str) -> "Compression":
5766 """Convert a string to a Compression enum value.
5867
5968 Args:
60- value (str): The string representation of the compression type.
61-
62- Returns:
63- Compression: The corresponding Compression enum value.
69+ value: The string representation of the compression type.
6470
6571 Raises:
6672 ValueError: If the string does not match any Compression enum value.
@@ -72,28 +78,19 @@ def from_str(cls, value: str) -> "Compression":
7278 f"{ value } is not a valid Compression. Valid values are: { [item .value for item in Compression ]} "
7379 )
7480
75- def get_default_level (self ) -> int :
76- """Get the default compression level for the compression type.
77-
78- Returns:
79- int: The default compression level.
80-
81- Raises:
82- KeyError: If the compression type does not have a default level.
83- """
84- # GZIP, BROTLI defaults from deltalake
81+ def get_default_level (self ) -> Optional [int ]:
82+ """Get the default compression level for the compression type."""
83+ # GZIP, BROTLI default values from deltalake repo
8584 # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
85+ # ZSTD default value from delta-rs
86+ # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
8687 if self == Compression .GZIP :
87- DEFAULT = 6
88+ return 6
8889 elif self == Compression .BROTLI :
89- DEFAULT = 1
90+ return 1
9091 elif self == Compression .ZSTD :
91- # ZSTD default from delta-rs
92- # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
93- DEFAULT = 4
94- else :
95- raise KeyError (f"{ self .value } does not have a compression level." )
96- return DEFAULT
92+ return 4
93+ return None
9794
9895
9996class DataFrame :
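A minimal sketch of the revised helper's behaviour. The import path is assumed from this module's location, and the SNAPPY member is an assumption not shown in this hunk; only GZIP, BROTLI, and ZSTD appear above.

from datafusion.dataframe import Compression  # import path assumed

# Codecs with a tunable level keep returning their documented defaults.
assert Compression.from_str("zstd").get_default_level() == 4
assert Compression.GZIP.get_default_level() == 6

# Codecs without a level now return None instead of raising KeyError.
assert Compression.SNAPPY.get_default_level() is None  # SNAPPY member assumed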
@@ -679,7 +676,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
    def write_parquet(
        self,
        path: str | pathlib.Path,
-        compression: str = Compression.ZSTD.value,
+        compression: Union[str, Compression] = Compression.ZSTD,
        compression_level: int | None = None,
    ) -> None:
        """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -700,13 +697,15 @@ def write_parquet(
                recommended range is 1 to 22, with the default being 4. Higher levels
                provide better compression but slower speed.
        """
-        compression_enum = Compression.from_str(compression)
+        # Convert string to Compression enum if necessary
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)

-        if compression_enum in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
+        if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
            if compression_level is None:
-                compression_level = compression_enum.get_default_level()
+                compression_level = compression.get_default_level()

-        self.df.write_parquet(str(path), compression_enum.value, compression_level)
+        self.df.write_parquet(str(path), compression.value, compression_level)

    def write_json(self, path: str | pathlib.Path) -> None:
        """Execute the :py:class:`DataFrame` and write the results to a JSON file.
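A hedged usage sketch of the updated write_parquet signature. The SessionContext setup and file names below are illustrative only; just the write_parquet calls mirror the API shown in this diff.

from datafusion import SessionContext
from datafusion.dataframe import Compression  # import path assumed

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})  # any DataFrame works here

# Passing the enum directly; ZSTD falls back to its default level of 4.
df.write_parquet("out_zstd.parquet", compression=Compression.ZSTD)

# Plain strings are still accepted and are routed through Compression.from_str.
df.write_parquet("out_gzip.parquet", compression="gzip", compression_level=9)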