Skip to content

Commit df7d65e

Browse files
committed
feat: implement Compression enum and update write_parquet method to use it
1 parent 56965f4 commit df7d65e

File tree

1 file changed

+49
-8
lines changed

1 file changed

+49
-8
lines changed

python/datafusion/dataframe.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@
3535

3636
from datafusion._internal import DataFrame as DataFrameInternal
3737
from datafusion.expr import Expr, SortExpr, sort_or_default
38+
from enum import Enum
39+
from typing import Tuple
40+
41+
42+
class Compression(Enum):
43+
UNCOMPRESSED = "uncompressed"
44+
SNAPPY = "snappy"
45+
GZIP = "gzip"
46+
BROTLI = "brotli"
47+
LZ4 = "lz4"
48+
LZ0 = "lz0"
49+
ZSTD = "zstd"
50+
LZ4_RAW = "lz4_raw"
51+
52+
@classmethod
53+
def from_str(cls, value: str) -> "Compression":
54+
try:
55+
return cls(value.lower())
56+
except ValueError:
57+
raise ValueError(
58+
f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
59+
)
60+
61+
def get_default_level(self) -> int:
62+
if self == Compression.GZIP:
63+
DEFAULT = 6
64+
elif self == Compression.BROTLI:
65+
DEFAULT = 1
66+
elif self == Compression.ZSTD:
67+
DEFAULT = 4
68+
else:
69+
raise KeyError(f"{self.value} does not have a compression level.")
70+
return DEFAULT
3871

3972

4073
class DataFrame:
@@ -620,26 +653,34 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
620653
def write_parquet(
621654
self,
622655
path: str | pathlib.Path,
623-
compression: str = "ZSTD",
656+
compression: str = Compression.ZSTD.value,
624657
compression_level: int | None = None,
625658
) -> None:
626659
"""Execute the :py:class:`DataFrame` and write the results to a Parquet file.
627660
628661
Args:
629662
path: Path of the Parquet file to write.
630663
compression: Compression type to use. Default is "ZSTD".
664+
Available compression types are:
665+
- "UNCOMPRESSED": No compression.
666+
- "SNAPPY": Snappy compression.
667+
- "GZIP": Gzip compression.
668+
- "BROTLI": Brotli compression.
669+
- "LZ0": LZ0 compression.
670+
- "LZ4": LZ4 compression.
671+
- "LZ4_RAW": LZ4_RAW compression.
672+
- "ZSTD": Zstandard compression.
631673
compression_level: Compression level to use. For ZSTD, the
632674
recommended range is 1 to 22, with the default being 4. Higher levels
633675
provide better compression but slower speed.
634676
"""
635-
if compression == "ZSTD":
677+
compression_enum = Compression.from_str(compression)
678+
679+
if compression_enum in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
636680
if compression_level is None:
637-
# Default compression level for ZSTD is 4 like in delta-rs
638-
# https://github.com/apache/datafusion-python/pull/981#discussion_r1899871918
639-
compression_level = 4
640-
elif not (1 <= compression_level <= 22):
641-
raise ValueError("Compression level for ZSTD must be between 1 and 22")
642-
self.df.write_parquet(str(path), compression, compression_level)
681+
compression_level = compression_enum.get_default_level()
682+
683+
self.df.write_parquet(str(path), compression_enum.value, compression_level)
643684

644685
def write_json(self, path: str | pathlib.Path) -> None:
645686
"""Execute the :py:class:`DataFrame` and write the results to a JSON file.

0 commit comments

Comments
 (0)