|
35 | 35 |
|
36 | 36 | from datafusion._internal import DataFrame as DataFrameInternal |
37 | 37 | from datafusion.expr import Expr, SortExpr, sort_or_default |
| 38 | +from enum import Enum |
| 39 | +from typing import Tuple |
| 40 | + |
| 41 | + |
| 42 | +class Compression(Enum): |
| 43 | +    UNCOMPRESSED = "uncompressed" |
| 44 | +    SNAPPY = "snappy" |
| 45 | +    GZIP = "gzip" |
| 46 | +    BROTLI = "brotli" |
| 47 | +    LZ4 = "lz4" |
| 48 | +    LZO = "lzo" |
| 49 | +    ZSTD = "zstd" |
| 50 | +    LZ4_RAW = "lz4_raw" |
| 51 | + |
| 52 | +    @classmethod |
| 53 | +    def from_str(cls, value: str) -> "Compression": |
| 54 | +        try: |
| 55 | +            return cls(value.lower()) |
| 56 | +        except ValueError as err: |
| 57 | +            valid = [item.value for item in Compression] |
| 58 | +            msg = f"{value} is not a valid Compression. Valid values are: {valid}" |
| 59 | +            raise ValueError(msg) from err |
| 60 | + |
| 61 | +    def get_default_level(self) -> int: |
| 62 | +        if self == Compression.GZIP: |
| 63 | +            DEFAULT = 6 |
| 64 | +        elif self == Compression.BROTLI: |
| 65 | +            DEFAULT = 1 |
| 66 | +        elif self == Compression.ZSTD: |
| 67 | +            DEFAULT = 4 |
| 68 | +        else: |
| 69 | +            raise KeyError(f"{self.value} does not have a compression level.") |
| 70 | +        return DEFAULT |
38 | 71 |
|
39 | 72 |
|
40 | 73 | class DataFrame: |
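As a quick illustration of how the Compression helpers added above behave (this snippet is not part of the diff, and the `from datafusion.dataframe import Compression` path is an assumption about where the enum lives once the change lands):

    from datafusion.dataframe import Compression  # assumed import path

    assert Compression.from_str("GZIP") is Compression.GZIP  # lookup is case-insensitive
    assert Compression.from_str("zstd").get_default_level() == 4
    assert Compression.GZIP.get_default_level() == 6
    assert Compression.BROTLI.get_default_level() == 1

    try:
        Compression.SNAPPY.get_default_level()  # codecs without a level raise KeyError
    except KeyError:
        pass

    try:
        Compression.from_str("bogus")  # unknown names raise ValueError listing the valid values
    except ValueError:
        pass
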
@@ -620,26 +653,34 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None |
620 | 653 |     def write_parquet( |
621 | 654 |         self, |
622 | 655 |         path: str | pathlib.Path, |
623 | | -        compression: str = "ZSTD", |
| 656 | +        compression: str = Compression.ZSTD.value, |
624 | 657 |         compression_level: int | None = None, |
625 | 658 |     ) -> None: |
626 | 659 |         """Execute the :py:class:`DataFrame` and write the results to a Parquet file. |
627 | 660 |
|
628 | 661 |         Args: |
629 | 662 |             path: Path of the Parquet file to write. |
630 | 663 |             compression: Compression type to use. Default is "ZSTD". |
| 664 | +                Available compression types are: |
| 665 | +                - "UNCOMPRESSED": No compression. |
| 666 | +                - "SNAPPY": Snappy compression. |
| 667 | +                - "GZIP": Gzip compression. |
| 668 | +                - "BROTLI": Brotli compression. |
| 669 | +                - "LZO": LZO compression. |
| 670 | +                - "LZ4": LZ4 compression. |
| 671 | +                - "LZ4_RAW": LZ4_RAW compression. |
| 672 | +                - "ZSTD": Zstandard compression. |
631 | 673 |             compression_level: Compression level to use. For ZSTD, the |
632 | 674 |                 recommended range is 1 to 22, with the default being 4. Higher levels |
633 | 675 |                 provide better compression but slower speed. |
634 | 676 |         """ |
635 | | -        if compression == "ZSTD": |
| 677 | +        compression_enum = Compression.from_str(compression) |
| 678 | + |
| 679 | +        if compression_enum in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}: |
636 | 680 |             if compression_level is None: |
637 | | -                # Default compression level for ZSTD is 4 like in delta-rs |
638 | | -                # https://github.com/apache/datafusion-python/pull/981#discussion_r1899871918 |
639 | | -                compression_level = 4 |
640 | | -            elif not (1 <= compression_level <= 22): |
641 | | -                raise ValueError("Compression level for ZSTD must be between 1 and 22") |
642 | | -        self.df.write_parquet(str(path), compression, compression_level) |
| 681 | +                compression_level = compression_enum.get_default_level() |
| 682 | + |
| 683 | +        self.df.write_parquet(str(path), compression_enum.value, compression_level) |
643 | 684 |
|
644 | 685 |     def write_json(self, path: str | pathlib.Path) -> None: |
645 | 686 |         """Execute the :py:class:`DataFrame` and write the results to a JSON file. |
|
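Taken together, the change keeps the string-based write_parquet API but routes it through the enum, so callers can pass any supported codec case-insensitively and omit the level for codecs that have a default. A minimal usage sketch, assuming a SessionContext with a trivial query (file names and the query are illustrative only):

    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT 1 AS a, 2 AS b")

    df.write_parquet("out_zstd.parquet")                              # default: zstd, level 4
    df.write_parquet("out_gzip.parquet", compression="gzip")          # gzip, level defaults to 6
    df.write_parquet("out_brotli.parquet", compression="BROTLI", compression_level=5)
    df.write_parquet("out_snappy.parquet", compression="snappy")      # snappy takes no level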