Skip to content

Commit d416a68

Browse files
committed
add csv write unit test
1 parent a8ccbcd commit d416a68

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

python/datafusion/dataframe.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ def write_csv(
941941
with_header: If true, output the CSV header row.
942942
write_options: Options that impact how the DataFrame is written.
943943
"""
944-
self.df.write_csv(str(path), with_header, write_options=write_options)
944+
self.df.write_csv(str(path), with_header, write_options._raw_write_options)
945945

946946
@overload
947947
def write_parquet(
@@ -1014,7 +1014,10 @@ def write_parquet(
10141014
compression_level = compression.get_default_level()
10151015

10161016
self.df.write_parquet(
1017-
str(path), compression.value, compression_level, write_options
1017+
str(path),
1018+
compression.value,
1019+
compression_level,
1020+
write_options._raw_write_options,
10181021
)
10191022

10201023
def write_parquet_with_options(
@@ -1071,7 +1074,7 @@ def write_parquet_with_options(
10711074
str(path),
10721075
options_internal,
10731076
column_specific_options_internal,
1074-
write_options,
1077+
write_options._raw_write_options,
10751078
)
10761079

10771080
def write_json(
@@ -1085,7 +1088,7 @@ def write_json(
10851088
path: Path of the JSON file to write.
10861089
write_options: Options that impact how the DataFrame is written.
10871090
"""
1088-
self.df.write_json(str(path), write_options=write_options)
1091+
self.df.write_json(str(path), write_options=write_options._raw_write_options)
10891092

10901093
def write_table(
10911094
self, table_name: str, write_options: DataFrameWriteOptions | None = None
@@ -1096,7 +1099,7 @@ def write_table(
10961099
Not all table providers support writing operations. See the individual
10971100
implementations for details.
10981101
"""
1099-
self.df.write_table(table_name, write_options)
1102+
self.df.write_table(table_name, write_options._raw_write_options)
11001103

11011104
def to_arrow_table(self) -> pa.Table:
11021105
"""Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1275,11 +1278,11 @@ def __init__(
12751278
"""Instantiate writer options for DataFrame."""
12761279
write_options = DataFrameWriteOptionsInternal()
12771280
if insert_operation is not None:
1278-
write_options = write_options.with_insert_operation(insert_operation)
1281+
write_options = write_options.with_insert_operation(insert_operation.value)
12791282
write_options = write_options.with_single_file_output(single_file_output)
12801283
if partition_by is not None:
12811284
if isinstance(partition_by, str):
1282-
partition_by = [single_file_output]
1285+
partition_by = [partition_by]
12831286
write_options = write_options.with_partition_by(partition_by)
12841287

12851288
sort_by_raw = sort_list_to_raw_sort_list(sort_by)

python/tests/test_dataframe.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from datafusion import (
4141
functions as f,
4242
)
43+
from datafusion.dataframe import DataFrameWriteOptions
4344
from datafusion.dataframe_formatter import (
4445
DataFrameHtmlFormatter,
4546
configure_formatter,
@@ -1830,6 +1831,33 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
18301831
assert result == expected
18311832

18321833

1834+
@pytest.mark.parametrize(
1835+
("sort_by", "expected_a"),
1836+
[
1837+
pytest.param(None, [1, 2, 3], id="unsorted"),
1838+
pytest.param(column("c"), [2, 1, 3], id="single_column_expr"),
1839+
pytest.param(
1840+
column("a").sort(ascending=False), [3, 2, 1], id="single_sort_expr"
1841+
),
1842+
pytest.param([column("c"), column("b")], [2, 1, 3], id="list_col_expr"),
1843+
pytest.param(
1844+
[column("c").sort(ascending=False), column("b").sort(ascending=False)],
1845+
[3, 1, 2],
1846+
id="list_sort_expr",
1847+
),
1848+
],
1849+
)
1850+
def test_write_csv_with_options(ctx, df, tmp_path, sort_by, expected_a) -> None:
1851+
write_options = DataFrameWriteOptions(sort_by=sort_by)
1852+
df.write_csv(tmp_path, with_header=True, write_options=write_options)
1853+
1854+
ctx.register_csv("csv", tmp_path)
1855+
result = ctx.table("csv").to_pydict()["a"]
1856+
ctx.table("csv").show()
1857+
1858+
assert result == expected_a
1859+
1860+
18331861
@pytest.mark.parametrize("path_to_str", [True, False])
18341862
def test_write_json(ctx, df, tmp_path, path_to_str):
18351863
path = str(tmp_path) if path_to_str else tmp_path

src/dataframe.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,10 +746,10 @@ impl PyDataFrame {
746746
/// Write a `DataFrame` to a CSV file.
747747
fn write_csv(
748748
&self,
749+
py: Python,
749750
path: &str,
750751
with_header: bool,
751752
write_options: Option<PyDataFrameWriteOptions>,
752-
py: Python,
753753
) -> PyDataFusionResult<()> {
754754
let csv_options = CsvOptions {
755755
has_header: Some(with_header),
@@ -1078,15 +1078,22 @@ impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
10781078
#[pymethods]
10791079
impl PyDataFrameWriteOptions {
10801080
#[new]
1081-
fn new(insert_operation: PyInsertOp) -> Self {
1081+
fn new() -> Self {
10821082
Self {
1083-
insert_operation: insert_operation.into(),
1083+
insert_operation: InsertOp::Append,
10841084
single_file_output: false,
10851085
partition_by: vec![],
10861086
sort_by: vec![],
10871087
}
10881088
}
10891089

1090+
pub fn with_insert_operation(&self, insert_operation: PyInsertOp) -> Self {
1091+
let mut result = self.clone();
1092+
1093+
result.insert_operation = insert_operation.into();
1094+
result
1095+
}
1096+
10901097
pub fn with_single_file_output(&self, single_file_output: bool) -> Self {
10911098
let mut result = self.clone();
10921099

0 commit comments

Comments
 (0)