Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 81 additions & 5 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,65 @@

from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr, SortExpr, sort_or_default
from enum import Enum


# excerpt from deltalake
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
class Compression(Enum):
    """Enum representing the available compression types for Parquet files."""

    UNCOMPRESSED = "uncompressed"
    SNAPPY = "snappy"
    GZIP = "gzip"
    BROTLI = "brotli"
    LZ4 = "lz4"
    LZ0 = "lz0"
    ZSTD = "zstd"
    LZ4_RAW = "lz4_raw"

    @classmethod
    def from_str(cls, value: str) -> "Compression":
        """Convert a string to a Compression enum value.

        The comparison is case-insensitive.

        Args:
            value: The string representation of the compression type.

        Returns:
            The corresponding Compression enum value.

        Raises:
            ValueError: If the string does not match any Compression enum value.
        """
        try:
            return cls(value.lower())
        except ValueError as err:
            # Re-raise with the list of valid options, chaining the original
            # error so the traceback shows the root cause.
            valid = [item.value for item in Compression]
            raise ValueError(
                f"{value} is not a valid Compression. Valid values are: {valid}"
            ) from err

    def get_default_level(self) -> int:
        """Get the default compression level for the compression type.

        Returns:
            The default compression level.

        Raises:
            KeyError: If the compression type does not have a default level.
        """
        # GZIP and BROTLI defaults taken from deltalake:
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
        # ZSTD default taken from delta-rs:
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
        defaults = {
            Compression.GZIP: 6,
            Compression.BROTLI: 1,
            Compression.ZSTD: 4,
        }
        if self not in defaults:
            raise KeyError(f"{self.value} does not have a compression level.")
        return defaults[self]


class DataFrame:
Expand Down Expand Up @@ -620,17 +679,34 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
def write_parquet(
    self,
    path: str | pathlib.Path,
    compression: str | Compression = Compression.ZSTD,
    compression_level: int | None = None,
) -> None:
    """Execute the :py:class:`DataFrame` and write the results to a Parquet file.

    Args:
        path: Path of the Parquet file to write.
        compression: Compression type to use, either as a string or a
            :py:class:`Compression` member. Default is ZSTD.
            Available compression types are:

            - "uncompressed": No compression.
            - "snappy": Snappy compression.
            - "gzip": Gzip compression.
            - "brotli": Brotli compression.
            - "lz0": LZ0 compression.
            - "lz4": LZ4 compression.
            - "lz4_raw": LZ4_RAW compression.
            - "zstd": Zstandard compression.
        compression_level: Compression level to use. For ZSTD, the
            recommended range is 1 to 22, with the default being 4. Higher levels
            provide better compression but slower speed.

    Raises:
        ValueError: If ``compression`` is a string that does not name a
            supported compression type.
    """
    # Accept either a plain string or a Compression enum member and
    # normalize to the enum (per review discussion on this method).
    if isinstance(compression, str):
        compression = Compression.from_str(compression)

    # These codecs accept a level; fill in the codec's default when the
    # caller did not supply one. For the level-less codecs, None is
    # passed through to the Rust layer untouched.
    if (
        compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}
        and compression_level is None
    ):
        compression_level = compression.get_default_level()

    self.df.write_parquet(str(path), compression.value, compression_level)

def write_json(self, path: str | pathlib.Path) -> None:
"""Execute the :py:class:`DataFrame` and write the results to a JSON file.
Expand Down
18 changes: 13 additions & 5 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,7 @@ def test_optimized_logical_plan(aggregate_df):
def test_execution_plan(aggregate_df):
    """The physical plan for the aggregate fixture displays as expected."""
    expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"  # noqa: E501
    assert aggregate_df.execution_plan().display() == expected

Expand Down Expand Up @@ -1107,14 +1105,24 @@ def test_write_compressed_parquet_wrong_compression_level(
)


@pytest.mark.parametrize("compression", ["wrong"])
def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
    """An unrecognized compression name must raise ValueError."""
    with pytest.raises(ValueError):
        df.write_parquet(str(tmp_path), compression=compression)


@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
    """Omitting compression_level for zstd/brotli/gzip uses each codec's
    default level and should complete without error."""
    df.write_parquet(str(tmp_path), compression=compression)


def test_dataframe_export(df) -> None:
# Guarantees that we have the canonical implementation
# reading our dataframe export
Expand Down
2 changes: 1 addition & 1 deletion src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ impl PyDataFrame {
/// Write a `DataFrame` to a Parquet file.
#[pyo3(signature = (
path,
compression="uncompressed",
compression="zstd",
compression_level=None
))]
fn write_parquet(
Expand Down
Loading