Commit 09f05fd

ion-elgreco authored and Liam Brannigan committed
feat: streamed write execution except cdf
Signed-off-by: Ion Koutsouris <[email protected]>
Signed-off-by: Liam Brannigan <[email protected]>
1 parent ea1d290 commit 09f05fd

File tree

4 files changed: 108 additions & 8 deletions

crates/core/src/operations/write/lazy.rs

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+//! This module contains helper functions to create a LazyTableProvider from an ArrowArrayStreamReader
+
+use crate::DeltaResult;
+use arrow::ffi_stream::ArrowArrayStreamReader;
+use datafusion::catalog::TableProvider;
+use datafusion::physical_plan::memory::LazyBatchGenerator;
+use delta_datafusion::LazyTableProvider;
+use parking_lot::RwLock;
+use std::fmt::{self};
+use std::sync::{Arc, Mutex};
+
+use crate::delta_datafusion;
+
+/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
+pub fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
+    use arrow::array::RecordBatchReader;
+    let schema = source.schema();
+    let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
+    let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
+        Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));
+
+    Ok(Arc::new(LazyTableProvider::try_new(
+        schema.clone(),
+        vec![arrow_stream_batch_generator],
+    )?))
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrowStreamBatchGenerator {
+    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
+}
+
+impl fmt::Display for ArrowStreamBatchGenerator {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
+            self.array_stream
+        )
+    }
+}
+
+impl ArrowStreamBatchGenerator {
+    pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
+        Self { array_stream }
+    }
+}
+
+impl LazyBatchGenerator for ArrowStreamBatchGenerator {
+    fn generate_next_batch(
+        &mut self,
+    ) -> datafusion::error::Result<Option<arrow::array::RecordBatch>> {
+        let mut stream_reader = self.array_stream.lock().map_err(|_| {
+            datafusion::error::DataFusionError::Execution(
+                "Failed to lock the ArrowArrayStreamReader".to_string(),
+            )
+        })?;
+
+        match stream_reader.next() {
+            Some(Ok(record_batch)) => Ok(Some(record_batch)),
+            Some(Err(err)) => Err(datafusion::error::DataFusionError::ArrowError(err, None)),
+            None => Ok(None), // End of stream
+        }
+    }
+}
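
The generator above is the whole streaming mechanism: DataFusion repeatedly calls generate_next_batch() and stops once it returns Ok(None). As a minimal sketch of that contract, here is the same LazyBatchGenerator trait implemented over an in-memory queue instead of an FFI stream. QueueBatchGenerator is a hypothetical type used only for illustration; the trait path and method signature are the ones the commit uses.

```rust
use std::collections::VecDeque;
use std::fmt;

use arrow::array::RecordBatch;
use datafusion::physical_plan::memory::LazyBatchGenerator;

/// Illustration only: yields pre-built batches one at a time, the way
/// ArrowStreamBatchGenerator yields batches pulled from an FFI stream.
#[derive(Debug)]
struct QueueBatchGenerator {
    batches: VecDeque<RecordBatch>,
}

impl fmt::Display for QueueBatchGenerator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "QueueBatchGenerator {{ remaining: {} }}", self.batches.len())
    }
}

impl LazyBatchGenerator for QueueBatchGenerator {
    fn generate_next_batch(
        &mut self,
    ) -> datafusion::error::Result<Option<RecordBatch>> {
        // Ok(None) signals end-of-stream, exactly like the FFI-backed
        // generator when the ArrowArrayStreamReader is exhausted.
        Ok(self.batches.pop_front())
    }
}
```

Because batches are produced on demand, the write path never has to hold the full source in memory, unlike the eager collect::<Vec<_>>() call that python/src/lib.rs drops below.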

crates/core/src/operations/write/mod.rs

Lines changed: 3 additions & 2 deletions
@@ -26,6 +26,7 @@
 pub mod configs;
 pub(crate) mod execution;
 pub(crate) mod generated_columns;
+pub mod lazy;
 pub(crate) mod schema_evolution;

 use arrow_schema::Schema;
@@ -44,7 +45,7 @@ use datafusion::datasource::MemTable;
 use datafusion::execution::context::{SessionContext, SessionState};
 use datafusion::prelude::DataFrame;
 use datafusion_common::{Column, DFSchema, Result, ScalarValue};
-use datafusion_expr::{cast, col, lit, Expr, LogicalPlan, UNNAMED_TABLE};
+use datafusion_expr::{cast, lit, Expr, LogicalPlan};
 use execution::{prepare_predicate_actions, write_execution_plan_with_predicate};
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
@@ -438,7 +439,7 @@ impl std::future::IntoFuture for WriteBuilder {
             .unwrap_or_default();

         let mut schema_drift = false;
-        let mut source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());
+        let source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());

         // Add missing generated columns to source_df
         let (mut source, missing_generated_columns) =

python/src/lib.rs

Lines changed: 28 additions & 2 deletions
@@ -50,6 +50,7 @@ use deltalake::operations::transaction::{
 };
 use deltalake::operations::update::UpdateBuilder;
 use deltalake::operations::vacuum::VacuumBuilder;
+use deltalake::operations::write::WriteBuilder;
 use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
 use deltalake::parquet::basic::Compression;
 use deltalake::parquet::errors::ParquetError;
@@ -2151,7 +2152,6 @@ fn write_to_deltalake(
     post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> PyResult<()> {
     py.allow_threads(|| {
-        let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<_>>();
         let save_mode = mode.parse().map_err(PythonError::from)?;

         let options = storage_options.clone().unwrap_or_default();
@@ -2164,7 +2164,33 @@
             .map_err(PythonError::from)?
         };

-        let mut builder = table.write(batches).with_save_mode(save_mode);
+        let dont_be_so_lazy = match table.0.state.as_ref() {
+            Some(state) => state.table_config().enable_change_data_feed(),
+            // You don't have state somehow, so I guess it's okay to be lazy.
+            _ => false,
+        };
+
+        let mut builder =
+            WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);
+
+        if dont_be_so_lazy {
+            debug!(
+                "write_to_deltalake() is not able to lazily perform a write, collecting batches"
+            );
+            builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
+        } else {
+            use deltalake::datafusion::datasource::provider_as_source;
+            use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
+            use deltalake::operations::write::lazy::to_lazy_table;
+            let table_provider = to_lazy_table(data.0).map_err(PythonError::from)?;
+
+            let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
+                .map_err(PythonError::from)?
+                .build()
+                .map_err(PythonError::from)?;
+            builder = builder.with_input_execution_plan(Arc::new(plan));
+        }
+
         if let Some(schema_mode) = schema_mode {
             builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
         }
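
Two things happen in the new branch above: when change data feed is enabled, batches are still collected eagerly (hence "except cdf" in the commit title); otherwise the ArrowArrayStreamReader is wrapped via to_lazy_table() into a TableProvider, which becomes a LogicalPlan scan node for the WriteBuilder. Here is a standalone sketch of that provider-to-plan step, using DataFusion's MemTable as a stand-in for the lazy provider; the main function, schema, and sample data are invented for illustration.

```rust
use std::sync::Arc;

use datafusion::arrow::array::{Int64Array, RecordBatch};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::{provider_as_source, MemTable};
use datafusion::logical_expr::LogicalPlanBuilder;

fn main() -> datafusion::error::Result<()> {
    // Stand-in provider; in write_to_deltalake() this is the LazyTableProvider
    // returned by to_lazy_table(data.0).
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
    )?;
    let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);

    // Same shape as the lazy branch: wrap the provider in a scan node and
    // build the LogicalPlan that is handed to with_input_execution_plan().
    let plan = LogicalPlanBuilder::scan("source", provider_as_source(provider), None)?
        .build()?;
    println!("{}", plan.display_indent());
    Ok(())
}
```

The scan node only holds a reference to the provider, so nothing is pulled from the underlying stream until the plan is actually executed.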

python/tests/test_writer.py

Lines changed: 12 additions & 4 deletions
@@ -1035,6 +1035,7 @@ def test_partition_overwrite(
         tmp_path, sample_data, mode="overwrite", predicate=f"p2 < {filter_string}"
     )

+
 @pytest.fixture()
 def sample_data_for_partitioning() -> pa.Table:
     return pa.table(
@@ -1590,9 +1591,13 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):

 def test_empty(existing_table: DeltaTable):
     schema = existing_table.schema().to_pyarrow()
+    expected = existing_table.to_pyarrow_table()
     empty_table = pa.Table.from_pylist([], schema=schema)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(existing_table, empty_table, mode="append", engine="rust")
+    write_deltalake(existing_table, empty_table, mode="append", engine="rust")
+
+    existing_table.update_incremental()
+    assert existing_table.version() == 1
+    assert expected == existing_table.to_pyarrow_table()


 def test_rust_decimal_cast(tmp_path: pathlib.Path):
@@ -1815,8 +1820,11 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
 def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
     empty_arrow_table = sample_data.schema.empty_table()
     empty_dataset = dataset(empty_arrow_table)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(tmp_path, empty_dataset, mode="append")
+    write_deltalake(tmp_path, empty_dataset, mode="append")
+    dt = DeltaTable(tmp_path)
+
+    new_dataset = dt.to_pyarrow_dataset()
+    assert new_dataset.count_rows() == 0


 @pytest.mark.pandas
