Skip to content

Commit 811f633

Browse files
committed
refactor: simplify Compression enum methods and improve type handling in DataFrame.write_parquet
1 parent 67529b8 commit 811f633

File tree

1 file changed

+27
-28
lines changed

1 file changed

+27
-28
lines changed

python/datafusion/dataframe.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@
2121

2222
from __future__ import annotations
2323
import warnings
24-
from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
24+
from typing import (
25+
Any,
26+
Iterable,
27+
List,
28+
TYPE_CHECKING,
29+
Literal,
30+
overload,
31+
Optional,
32+
Union,
33+
)
2534
from datafusion.record_batch import RecordBatchStream
2635
from typing_extensions import deprecated
2736
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -57,10 +66,7 @@ def from_str(cls, value: str) -> "Compression":
5766
"""Convert a string to a Compression enum value.
5867
5968
Args:
60-
value (str): The string representation of the compression type.
61-
62-
Returns:
63-
Compression: The corresponding Compression enum value.
69+
value: The string representation of the compression type.
6470
6571
Raises:
6672
ValueError: If the string does not match any Compression enum value.
@@ -72,28 +78,19 @@ def from_str(cls, value: str) -> "Compression":
7278
f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
7379
)
7480

75-
def get_default_level(self) -> int:
76-
"""Get the default compression level for the compression type.
77-
78-
Returns:
79-
int: The default compression level.
80-
81-
Raises:
82-
KeyError: If the compression type does not have a default level.
83-
"""
84-
# GZIP, BROTLI defaults from deltalake
81+
def get_default_level(self) -> Optional[int]:
82+
"""Get the default compression level for the compression type."""
83+
# GZIP, BROTLI default values from deltalake repo
8584
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
85+
# ZSTD default value from delta-rs
86+
# https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
8687
if self == Compression.GZIP:
87-
DEFAULT = 6
88+
return 6
8889
elif self == Compression.BROTLI:
89-
DEFAULT = 1
90+
return 1
9091
elif self == Compression.ZSTD:
91-
# ZSTD default from delta-rs
92-
# https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
93-
DEFAULT = 4
94-
else:
95-
raise KeyError(f"{self.value} does not have a compression level.")
96-
return DEFAULT
92+
return 4
93+
return None
9794

9895

9996
class DataFrame:
@@ -679,7 +676,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
679676
def write_parquet(
680677
self,
681678
path: str | pathlib.Path,
682-
compression: str = Compression.ZSTD.value,
679+
compression: Union[str, Compression] = Compression.ZSTD,
683680
compression_level: int | None = None,
684681
) -> None:
685682
"""Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -700,13 +697,15 @@ def write_parquet(
700697
recommended range is 1 to 22, with the default being 4. Higher levels
701698
provide better compression but slower speed.
702699
"""
703-
compression_enum = Compression.from_str(compression)
700+
# Convert string to Compression enum if necessary
701+
if isinstance(compression, str):
702+
compression = Compression.from_str(compression)
704703

705-
if compression_enum in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
704+
if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
706705
if compression_level is None:
707-
compression_level = compression_enum.get_default_level()
706+
compression_level = compression.get_default_level()
708707

709-
self.df.write_parquet(str(path), compression_enum.value, compression_level)
708+
self.df.write_parquet(str(path), compression.value, compression_level)
710709

711710
def write_json(self, path: str | pathlib.Path) -> None:
712711
"""Execute the :py:class:`DataFrame` and write the results to a JSON file.

0 commit comments

Comments
 (0)