
Commit 0d191f6

more testing around writer options
1 parent: 6720ee5

File tree

4 files changed, +92 −72 lines changed


python/datafusion/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -45,7 +45,13 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions
+from .dataframe import (
+    DataFrame,
+    DataFrameWriteOptions,
+    InsertOp,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .dataframe_formatter import configure_formatter
 from .expr import (
     Expr,
@@ -75,9 +81,11 @@
     "Config",
     "DFSchema",
     "DataFrame",
+    "DataFrameWriteOptions",
     "Database",
     "ExecutionPlan",
     "Expr",
+    "InsertOp",
     "LogicalPlan",
     "ParquetColumnOptions",
     "ParquetWriterOptions",

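For context, the two new re-exports make the writer options usable without reaching into datafusion.dataframe directly. A minimal usage sketch, assuming data shaped like the test fixture; the output path is illustrative, and InsertOp only matters for table writes:

from datafusion import DataFrameWriteOptions, SessionContext, column

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]})

# Sort the output by column "c" and force a single output file.
opts = DataFrameWriteOptions(single_file_output=True, sort_by=column("c").sort(ascending=True))
df.write_csv("/tmp/sorted.csv", with_header=True, write_options=opts)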
python/datafusion/dataframe.py

Lines changed: 25 additions & 16 deletions
@@ -941,7 +941,10 @@ def write_csv(
             with_header: If true, output the CSV header row.
             write_options: Options that impact how the DataFrame is written.
         """
-        self.df.write_csv(str(path), with_header, write_options._raw_write_options)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_csv(str(path), with_header, raw_write_options)

     @overload
     def write_parquet(
@@ -1013,11 +1016,14 @@ def write_parquet(
         ):
             compression_level = compression.get_default_level()

+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
         self.df.write_parquet(
             str(path),
             compression.value,
             compression_level,
-            write_options._raw_write_options,
+            raw_write_options,
         )

     def write_parquet_with_options(
@@ -1070,11 +1076,14 @@ def write_parquet_with_options(
                 bloom_filter_ndv=opts.bloom_filter_ndv,
             )

+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
         self.df.write_parquet_with_options(
             str(path),
             options_internal,
             column_specific_options_internal,
-            write_options._raw_write_options,
+            raw_write_options,
         )

     def write_json(
@@ -1088,7 +1097,10 @@ def write_json(
             path: Path of the JSON file to write.
             write_options: Options that impact how the DataFrame is written.
         """
-        self.df.write_json(str(path), write_options=write_options._raw_write_options)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_json(str(path), write_options=raw_write_options)

     def write_table(
         self, table_name: str, write_options: DataFrameWriteOptions | None = None
@@ -1099,7 +1111,10 @@ def write_table(
         Not all table providers support writing operations. See the individual
         implementations for details.
         """
-        self.df.write_table(table_name, write_options._raw_write_options)
+        raw_write_options = (
+            write_options._raw_write_options if write_options is not None else None
+        )
+        self.df.write_table(table_name, raw_write_options)

     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1284,17 +1299,11 @@ def __init__(
         sort_by: Expr | SortExpr | Sequence[Expr] | Sequence[SortExpr] | None = None,
     ) -> None:
         """Instantiate writer options for DataFrame."""
-        write_options = DataFrameWriteOptionsInternal()
-        if insert_operation is not None:
-            write_options = write_options.with_insert_operation(insert_operation.value)
-        write_options = write_options.with_single_file_output(single_file_output)
-        if partition_by is not None:
-            if isinstance(partition_by, str):
-                partition_by = [partition_by]
-            write_options = write_options.with_partition_by(partition_by)
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]

         sort_by_raw = sort_list_to_raw_sort_list(sort_by)
-        if sort_by_raw is not None:
-            write_options = write_options.with_sort_by(sort_by_raw)

-        self._raw_write_options = write_options
+        self._raw_write_options = DataFrameWriteOptionsInternal(
+            insert_operation, single_file_output, partition_by, sort_by_raw
+        )
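The recurring raw_write_options guard is the functional change here: every writer method now accepts write_options=None (its documented default) instead of unconditionally dereferencing ._raw_write_options, and DataFrameWriteOptions.__init__ builds the internal object in a single call. A short calling-side sketch, with illustrative data and paths:

from datafusion import DataFrameWriteOptions, SessionContext, column

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]})

# Omitting write_options previously dereferenced None and raised AttributeError;
# with the guard above, None is forwarded to the Rust layer untouched.
df.write_json("/tmp/plain.json")

# Passing options still works and exercises the new one-shot constructor.
sorted_opts = DataFrameWriteOptions(sort_by=column("a").sort(ascending=False))
df.write_json("/tmp/sorted.json", write_options=sorted_opts)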

python/tests/test_dataframe.py

Lines changed: 42 additions & 20 deletions
@@ -16,6 +16,7 @@
 # under the License.
 import ctypes
 import datetime
+import itertools
 import os
 import re
 import threading
@@ -59,9 +60,7 @@ def ctx():


 @pytest.fixture
-def df():
-    ctx = SessionContext()
-
+def df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
@@ -1831,29 +1830,52 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected


+sort_by_cases = [
+    (None, [1, 2, 3], "unsorted"),
+    (column("c"), [2, 1, 3], "single_column_expr"),
+    (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"),
+    ([column("c"), column("b")], [2, 1, 3], "list_col_expr"),
+    (
+        [column("c").sort(ascending=False), column("b").sort(ascending=False)],
+        [3, 1, 2],
+        "list_sort_expr",
+    ),
+]
+
+formats = ["csv", "json", "parquet", "table"]
+
+
 @pytest.mark.parametrize(
-    ("sort_by", "expected_a"),
+    ("format", "sort_by", "expected_a"),
     [
-        pytest.param(None, [1, 2, 3], id="unsorted"),
-        pytest.param(column("c"), [2, 1, 3], id="single_column_expr"),
-        pytest.param(
-            column("a").sort(ascending=False), [3, 2, 1], id="single_sort_expr"
-        ),
-        pytest.param([column("c"), column("b")], [2, 1, 3], id="list_col_expr"),
-        pytest.param(
-            [column("c").sort(ascending=False), column("b").sort(ascending=False)],
-            [3, 1, 2],
-            id="list_sort_expr",
-        ),
+        pytest.param(format, sort_by, expected_a, id=f"{format}_{test_id}")
+        for format, (sort_by, expected_a, test_id) in itertools.product(
+            formats, sort_by_cases
+        )
     ],
 )
-def test_write_csv_with_options(ctx, df, tmp_path, sort_by, expected_a) -> None:
+def test_write_files_with_options(
+    ctx, df, tmp_path, format, sort_by, expected_a
+) -> None:
     write_options = DataFrameWriteOptions(sort_by=sort_by)
-    df.write_csv(tmp_path, with_header=True, write_options=write_options)

-    ctx.register_csv("csv", tmp_path)
-    result = ctx.table("csv").to_pydict()["a"]
-    ctx.table("csv").show()
+    if format == "csv":
+        df.write_csv(tmp_path, with_header=True, write_options=write_options)
+        ctx.register_csv("test_table", tmp_path)
+    elif format == "json":
+        df.write_json(tmp_path, write_options=write_options)
+        ctx.register_json("test_table", tmp_path)
+    elif format == "parquet":
+        df.write_parquet(tmp_path, write_options=write_options)
+        ctx.register_parquet("test_table", tmp_path)
+    elif format == "table":
+        batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema())
+        ctx.register_record_batches("test_table", [[batch]])
+        ctx.table("test_table").show()
+        df.write_table("test_table", write_options=write_options)
+
+    result = ctx.table("test_table").to_pydict()["a"]
+    ctx.table("test_table").show()

     assert result == expected_a
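The old CSV-only test is folded into a single parametrization that crosses every writer path with every sort case; itertools.product yields one pytest id per (format, case) pair. A small, self-contained illustration of how those ids expand (case list truncated, purely illustrative):

import itertools

formats = ["csv", "json", "parquet", "table"]
sort_by_cases = [(None, [1, 2, 3], "unsorted"), ("c", [2, 1, 3], "single_column_expr")]

ids = [
    f"{fmt}_{case_id}"
    for fmt, (_sort_by, _expected_a, case_id) in itertools.product(formats, sort_by_cases)
]
# ['csv_unsorted', 'csv_single_column_expr', 'json_unsorted', ...,
#  'table_unsorted', 'table_single_column_expr']
print(ids)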

src/dataframe.rs

Lines changed: 16 additions & 35 deletions
@@ -1078,44 +1078,25 @@ impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
 #[pymethods]
 impl PyDataFrameWriteOptions {
     #[new]
-    fn new() -> Self {
+    fn new(
+        insert_operation: Option<PyInsertOp>,
+        single_file_output: bool,
+        partition_by: Option<Vec<String>>,
+        sort_by: Option<Vec<PySortExpr>>,
+    ) -> Self {
+        let insert_operation = insert_operation.map(Into::into).unwrap_or(InsertOp::Append);
+        let sort_by = sort_by
+            .unwrap_or_default()
+            .into_iter()
+            .map(Into::into)
+            .collect();
         Self {
-            insert_operation: InsertOp::Append,
-            single_file_output: false,
-            partition_by: vec![],
-            sort_by: vec![],
+            insert_operation,
+            single_file_output,
+            partition_by: partition_by.unwrap_or_default(),
+            sort_by,
         }
     }
-
-    pub fn with_insert_operation(&self, insert_operation: PyInsertOp) -> Self {
-        let mut result = self.clone();
-
-        result.insert_operation = insert_operation.into();
-        result
-    }
-
-    pub fn with_single_file_output(&self, single_file_output: bool) -> Self {
-        let mut result = self.clone();
-
-        result.single_file_output = single_file_output;
-        result
-    }
-
-    /// Sets the partition_by columns for output partitioning
-    pub fn with_partition_by(&self, partition_by: Vec<String>) -> Self {
-        let mut result = self.clone();
-
-        result.partition_by = partition_by;
-        result
-    }
-
-    /// Sets the sort_by columns for output sorting
-    pub fn with_sort_by(&self, sort_by: Vec<PySortExpr>) -> Self {
-        let mut result = self.clone();
-
-        result.sort_by = sort_by.into_iter().map(Into::into).collect();
-        result
-    }
 }

 /// Print DataFrame
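With the with_* builders removed, the internal options object is configured entirely through its constructor, which is the single call python/datafusion/dataframe.py now makes. A hedged Python-side sketch of that call shape; the _internal import path and class alias are assumptions, and the positional arguments mirror fn new above:

from datafusion._internal import (  # import path assumed
    DataFrameWriteOptions as DataFrameWriteOptionsInternal,
)

# None lets the Rust side fall back to its defaults: InsertOp::Append,
# an empty partition_by vector, and an empty sort_by vector.
raw = DataFrameWriteOptionsInternal(
    None,        # insert_operation: Option<PyInsertOp>
    False,       # single_file_output: bool
    ["year"],    # partition_by: Option<Vec<String>>
    None,        # sort_by: Option<Vec<PySortExpr>>
)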
