
Commit 819451f

Initial commit for dataframe write_table

1 parent e75addf

File tree

4 files changed (+212 -16 lines)

python/datafusion/dataframe.py

Lines changed: 57 additions & 0 deletions
@@ -39,10 +39,13 @@
 from typing_extensions import deprecated  # Python 3.12

 from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion._internal import DataFrameWriteOptions as DataFrameWriteOptionsInternal
+from datafusion._internal import InsertOp as InsertOpInternal
 from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
 from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
 from datafusion.expr import (
     Expr,
+    SortExpr,
     SortKey,
     ensure_expr,
     ensure_expr_list,
@@ -1060,6 +1063,17 @@ def write_json(self, path: str | pathlib.Path) -> None:
         """
         self.df.write_json(str(path))

+    def write_table(
+        self, table_name: str, write_options: DataFrameWriteOptions | None = None
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to a table.
+
+        The table must be registered with the session to perform this operation.
+        Not all table providers support write operations; see the individual
+        implementations for details.
+        """
+        raw = write_options._raw_write_options if write_options is not None else None
+        self.df.write_table(table_name, raw)
+
     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
@@ -1206,3 +1220,46 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame:
         - For columns not in subset, the original column is kept unchanged
         """
         return DataFrame(self.df.fill_null(value, subset))
+
+
+class InsertOp(Enum):
+    """Insert operation mode.
+
+    These modes are used by the table writing feature to define how record
+    batches should be written to a table.
+    """
+
+    APPEND = InsertOpInternal.APPEND
+    REPLACE = InsertOpInternal.REPLACE
+    OVERWRITE = InsertOpInternal.OVERWRITE
+
+
+class DataFrameWriteOptions:
+    """Writer options for DataFrame.
+
+    There is no guarantee that a table provider supports all writer options.
+    See the individual implementations and documentation for details.
+    """
+
+    def __init__(
+        self,
+        insert_operation: InsertOp | None = None,
+        single_file_output: bool = False,
+        partition_by: str | Sequence[str] | None = None,
+        sort_by: Expr | SortExpr | Sequence[Expr] | Sequence[SortExpr] | None = None,
+    ) -> None:
+        """Instantiate writer options for DataFrame."""
+        write_options = DataFrameWriteOptionsInternal()
+        if insert_operation is not None:
+            write_options = write_options.with_insert_operation(insert_operation.value)
+        write_options = write_options.with_single_file_output(single_file_output)
+        if partition_by is not None:
+            if isinstance(partition_by, str):
+                partition_by = [partition_by]
+            write_options = write_options.with_partition_by(partition_by)
+
+        sort_by_raw = sort_list_to_raw_sort_list(sort_by)
+        if sort_by_raw is not None:
+            write_options = write_options.with_sort_by(sort_by_raw)
+
+        self._raw_write_options = write_options
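
For context, a minimal usage sketch of the new Python API. The session setup and the table name "target" are illustrative only; just write_table, DataFrameWriteOptions, and InsertOp come from the diff above.

    import pyarrow as pa
    from datafusion import SessionContext, column, literal
    from datafusion.dataframe import DataFrameWriteOptions, InsertOp

    # Illustrative setup: register an in-memory table with the session.
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
    ctx.register_record_batches("target", [[batch]])

    # Append a derived result back into the registered table, sorted by "a".
    df = ctx.table("target").with_column("a", column("a") * literal(-1))
    options = DataFrameWriteOptions(
        insert_operation=InsertOp.APPEND,
        sort_by=column("a").sort(),
    )
    df.write_table("target", options)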

python/tests/test_dataframe.py

Lines changed: 19 additions & 0 deletions
@@ -2322,6 +2322,25 @@ def test_write_parquet_options_error(df, tmp_path):
         df.write_parquet(str(tmp_path), options, compression_level=1)


+def test_write_table(ctx, df):
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3])],
+        names=["a"],
+    )
+
+    ctx.register_record_batches("t", [[batch]])
+
+    df = ctx.table("t").with_column("a", column("a") * literal(-1))
+
+    # The default insert operation is append, so the negated rows are
+    # written alongside the original ones.
+    df.write_table("t")
+    result = ctx.table("t").sort(column("a")).collect()[0][0].to_pylist()
+    expected = [-3, -2, -1, 1, 2, 3]
+
+    assert result == expected
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
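
The test above relies on the default insert mode, append: writing the negated frame back into "t" keeps the original three rows and adds three more, six in total. A hypothetical variant that replaces the contents instead would pass explicit options; whether it succeeds depends on the table provider, and this commit adds no overwrite coverage:

    from datafusion.dataframe import DataFrameWriteOptions, InsertOp

    # Hypothetical: replace the table contents instead of appending.
    # Providers that only support appends will raise an error here.
    overwrite = DataFrameWriteOptions(insert_operation=InsertOp.OVERWRITE)
    df.write_table("t", overwrite)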

src/dataframe.rs

Lines changed: 134 additions & 16 deletions
@@ -34,6 +34,8 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;
 use datafusion::error::DataFusionError;
 use datafusion::execution::SendableRecordBatchStream;
+use datafusion::logical_expr::dml::InsertOp;
+use datafusion::logical_expr::SortExpr;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
 use datafusion_ffi::table_provider::FFI_TableProvider;
@@ -742,18 +744,27 @@ impl PyDataFrame {
     }

     /// Write a `DataFrame` to a CSV file.
-    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
+    fn write_csv(
+        &self,
+        path: &str,
+        with_header: bool,
+        write_options: Option<PyDataFrameWriteOptions>,
+        py: Python,
+    ) -> PyDataFusionResult<()> {
         let csv_options = CsvOptions {
             has_header: Some(with_header),
             ..Default::default()
         };
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
+
         wait_for_future(
             py,
-            self.df.as_ref().clone().write_csv(
-                path,
-                DataFrameWriteOptions::new(),
-                Some(csv_options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_csv(path, write_options, Some(csv_options)),
         )??;
         Ok(())
     }
@@ -762,13 +773,15 @@ impl PyDataFrame {
     #[pyo3(signature = (
         path,
         compression="zstd",
-        compression_level=None
+        compression_level=None,
+        write_options=None,
     ))]
     fn write_parquet(
         &self,
         path: &str,
         compression: &str,
         compression_level: Option<u32>,
+        write_options: Option<PyDataFrameWriteOptions>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
@@ -807,14 +820,16 @@ impl PyDataFrame {

         let mut options = TableParquetOptions::default();
         options.global.compression = Some(compression_string);
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();

         wait_for_future(
             py,
-            self.df.as_ref().clone().write_parquet(
-                path,
-                DataFrameWriteOptions::new(),
-                Option::from(options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_parquet(path, write_options, Option::from(options)),
         )??;
         Ok(())
     }
@@ -825,6 +840,7 @@ impl PyDataFrame {
         path: &str,
         options: PyParquetWriterOptions,
         column_specific_options: HashMap<String, PyParquetColumnOptions>,
+        write_options: Option<PyDataFrameWriteOptions>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         let table_options = TableParquetOptions {
@@ -835,26 +851,55 @@ impl PyDataFrame {
                 .collect(),
             ..Default::default()
         };
-
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
         wait_for_future(
             py,
             self.df.as_ref().clone().write_parquet(
                 path,
-                DataFrameWriteOptions::new(),
+                write_options,
                 Option::from(table_options),
             ),
         )??;
         Ok(())
     }

     /// Executes a query and writes the results to a partitioned JSON file.
-    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
+    fn write_json(
+        &self,
+        path: &str,
+        py: Python,
+        write_options: Option<PyDataFrameWriteOptions>,
+    ) -> PyDataFusionResult<()> {
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
         wait_for_future(
             py,
             self.df
                 .as_ref()
                 .clone()
-                .write_json(path, DataFrameWriteOptions::new(), None),
+                .write_json(path, write_options, None),
+        )??;
+        Ok(())
+    }
+
+    fn write_table(
+        &self,
+        py: Python,
+        table_name: &str,
+        write_options: Option<PyDataFrameWriteOptions>,
+    ) -> PyDataFusionResult<()> {
+        let write_options = write_options
+            .map(DataFrameWriteOptions::from)
+            .unwrap_or_default();
+        wait_for_future(
+            py,
+            self.df
+                .as_ref()
+                .clone()
+                .write_table(table_name, write_options),
         )??;
         Ok(())
     }
@@ -993,6 +1038,79 @@ impl PyDataFrame {
     }
 }

+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[pyclass(eq, eq_int, name = "InsertOp", module = "datafusion")]
+pub enum PyInsertOp {
+    APPEND,
+    REPLACE,
+    OVERWRITE,
+}
+
+impl From<PyInsertOp> for InsertOp {
+    fn from(value: PyInsertOp) -> Self {
+        match value {
+            PyInsertOp::APPEND => InsertOp::Append,
+            PyInsertOp::REPLACE => InsertOp::Replace,
+            PyInsertOp::OVERWRITE => InsertOp::Overwrite,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+#[pyclass(name = "DataFrameWriteOptions", module = "datafusion")]
+pub struct PyDataFrameWriteOptions {
+    insert_operation: InsertOp,
+    single_file_output: bool,
+    partition_by: Vec<String>,
+    sort_by: Vec<SortExpr>,
+}
+
+impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
+    fn from(value: PyDataFrameWriteOptions) -> Self {
+        DataFrameWriteOptions::new()
+            .with_insert_operation(value.insert_operation)
+            .with_single_file_output(value.single_file_output)
+            .with_partition_by(value.partition_by)
+            .with_sort_by(value.sort_by)
+    }
+}
+
+#[pymethods]
+impl PyDataFrameWriteOptions {
+    #[new]
+    #[pyo3(signature = (insert_operation=None))]
+    fn new(insert_operation: Option<PyInsertOp>) -> Self {
+        Self {
+            insert_operation: insert_operation.map(Into::into).unwrap_or(InsertOp::Append),
+            single_file_output: false,
+            partition_by: vec![],
+            sort_by: vec![],
+        }
+    }
+
+    /// Sets the insert operation for the write
+    pub fn with_insert_operation(&self, insert_operation: PyInsertOp) -> Self {
+        let mut result = self.clone();
+        result.insert_operation = insert_operation.into();
+        result
+    }
+
+    pub fn with_single_file_output(&self, single_file_output: bool) -> Self {
+        let mut result = self.clone();
+        result.single_file_output = single_file_output;
+        result
+    }
+
+    /// Sets the partition_by columns for output partitioning
+    pub fn with_partition_by(&self, partition_by: Vec<String>) -> Self {
+        let mut result = self.clone();
+        result.partition_by = partition_by;
+        result
+    }
+
+    /// Sets the sort_by columns for output sorting
+    pub fn with_sort_by(&self, sort_by: Vec<PySortExpr>) -> Self {
+        let mut result = self.clone();
+        result.sort_by = sort_by.into_iter().map(Into::into).collect();
+        result
+    }
+}
+
 /// Print DataFrame
 fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
     // Get string representation of record batches
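
Each with_* method above clones self and returns the modified copy, so the options object behaves as an immutable builder from Python; DataFrameWriteOptions.__init__ in dataframe.py drives exactly this chain. A sketch against the _internal bindings, assuming the no-argument constructor shown above (the partition columns are made up):

    from datafusion._internal import DataFrameWriteOptions as RawOptions
    from datafusion._internal import InsertOp as RawInsertOp

    # Each call returns a fresh options value; earlier values are unchanged.
    opts = (
        RawOptions()
        .with_insert_operation(RawInsertOp.APPEND)
        .with_single_file_output(True)
        .with_partition_by(["year", "month"])  # hypothetical partition columns
    )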

src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<context::PySessionContext>()?;
     m.add_class::<context::PySQLOptions>()?;
     m.add_class::<dataframe::PyDataFrame>()?;
+    m.add_class::<dataframe::PyInsertOp>()?;
+    m.add_class::<dataframe::PyDataFrameWriteOptions>()?;
     m.add_class::<dataframe::PyParquetColumnOptions>()?;
     m.add_class::<dataframe::PyParquetWriterOptions>()?;
     m.add_class::<udf::PyScalarUDF>()?;
