Skip to content

Commit aedeaaa

Browse files
committed
init
1 parent 8f4232f commit aedeaaa

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

python/deltalake/_internal.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@ class RawDeltaTable:
221221
starting_timestamp: Optional[str] = None,
222222
ending_timestamp: Optional[str] = None,
223223
) -> pyarrow.RecordBatchReader: ...
224+
def datafusion_read(
225+
self,
226+
predicate: Optional[str] = None,
227+
columns: Optional[List[str]] = None,
228+
) -> None: ...
224229

225230
def rust_core_version() -> str: ...
226231
def write_new_deltalake(

python/deltalake/table.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,13 @@ def repair(
14171417
)
14181418
return json.loads(metrics)
14191419

1420+
def datafusion_read(
1421+
self,
1422+
predicate: Optional[str] = None,
1423+
columns: Optional[List[str]] = None,
1424+
) -> List[pyarrow.RecordBatch]:
1425+
return self._table.datafusion_read(predicate, columns)
1426+
14201427

14211428
class TableMerger:
14221429
"""API for various table `MERGE` commands."""

python/src/lib.rs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod utils;
88
use std::collections::{HashMap, HashSet};
99
use std::future::IntoFuture;
1010
use std::str::FromStr;
11+
use std::sync::Arc;
1112
use std::time;
1213
use std::time::{SystemTime, UNIX_EPOCH};
1314

@@ -17,12 +18,18 @@ use delta_kernel::expressions::Scalar;
1718
use delta_kernel::schema::StructField;
1819
use deltalake::arrow::compute::concat_batches;
1920
use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
21+
use deltalake::arrow::pyarrow::ToPyArrow;
2022
use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
2123
use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
2224
use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
25+
use deltalake::datafusion::datasource::provider_as_source;
26+
use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
2327
use deltalake::datafusion::physical_plan::ExecutionPlan;
24-
use deltalake::datafusion::prelude::SessionContext;
25-
use deltalake::delta_datafusion::DeltaDataChecker;
28+
use deltalake::datafusion::prelude::{DataFrame, SessionContext};
29+
use deltalake::delta_datafusion::{
30+
DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
31+
DeltaTableProvider,
32+
};
2633
use deltalake::errors::DeltaTableError;
2734
use deltalake::kernel::{
2835
scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType,
@@ -1232,6 +1239,65 @@ impl RawDeltaTable {
12321239
self._table.state = table.state;
12331240
Ok(serde_json::to_string(&metrics).unwrap())
12341241
}
1242+
1243+
#[pyo3(signature = (predicate = None, columns = None))]
1244+
pub fn datafusion_read(
1245+
&self,
1246+
py: Python,
1247+
predicate: Option<String>,
1248+
columns: Option<Vec<String>>,
1249+
) -> PyResult<PyObject> {
1250+
let batches = py.allow_threads(|| -> PyResult<_> {
1251+
let snapshot = self._table.snapshot().map_err(PythonError::from)?;
1252+
let log_store = self._table.log_store();
1253+
1254+
let scan_config = DeltaScanConfigBuilder::default()
1255+
.with_parquet_pushdown(false)
1256+
.build(snapshot)
1257+
.map_err(PythonError::from)?;
1258+
1259+
let provider = Arc::new(
1260+
DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
1261+
.map_err(PythonError::from)?,
1262+
);
1263+
let source = provider_as_source(provider);
1264+
1265+
let config = DeltaSessionConfig::default().into();
1266+
let session = SessionContext::new_with_config(config);
1267+
let state = session.state();
1268+
1269+
let maybe_filter = predicate
1270+
.map(|predicate| snapshot.parse_predicate_expression(predicate, &state))
1271+
.transpose()
1272+
.map_err(PythonError::from)?;
1273+
1274+
let filters = match &maybe_filter {
1275+
Some(filter) => vec![filter.clone()],
1276+
None => vec![],
1277+
};
1278+
1279+
let plan = LogicalPlanBuilder::scan_with_filters(UNNAMED_TABLE, source, None, filters)
1280+
.unwrap()
1281+
.build()
1282+
.unwrap();
1283+
1284+
let mut df = DataFrame::new(state, plan);
1285+
1286+
if let Some(filter) = maybe_filter {
1287+
df = df.filter(filter).unwrap();
1288+
}
1289+
1290+
if let Some(columns) = columns {
1291+
df = df
1292+
.select_columns(&columns.iter().map(String::as_str).collect::<Vec<_>>())
1293+
.unwrap();
1294+
}
1295+
1296+
Ok(rt().block_on(async { df.collect().await }).unwrap())
1297+
})?;
1298+
1299+
batches.to_pyarrow(py)
1300+
}
12351301
}
12361302

12371303
fn set_post_commithook_properties(

python/tests/test_table_read.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,67 @@ def test_is_deltatable_with_storage_opts():
946946
"DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
947947
}
948948
assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)
949+
950+
951+
def test_datafusion_read_table():
952+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
953+
dt = DeltaTable(table_path)
954+
expected = {
955+
"value": ["1", "2", "3", "4", "5", "6", "7"],
956+
"year": ["2020", "2020", "2020", "2021", "2021", "2021", "2021"],
957+
"month": ["1", "2", "2", "4", "12", "12", "12"],
958+
"day": ["1", "3", "5", "5", "4", "20", "20"],
959+
}
960+
actual = pa.Table.from_batches(dt.datafusion_read()).sort_by("value").to_pydict()
961+
assert expected == actual
962+
963+
964+
def test_datafusion_read_table_with_columns():
965+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
966+
dt = DeltaTable(table_path)
967+
expected = {
968+
"value": ["1", "2", "3", "4", "5", "6", "7"],
969+
"day": ["1", "3", "5", "5", "4", "20", "20"],
970+
}
971+
actual = (
972+
pa.Table.from_batches(dt.datafusion_read(columns=["value", "day"]))
973+
.sort_by("value")
974+
.to_pydict()
975+
)
976+
assert expected == actual
977+
978+
979+
def test_datafusion_read_with_filter_on_partitioned_column():
980+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
981+
dt = DeltaTable(table_path)
982+
expected = {
983+
"value": ["1", "2", "3"],
984+
"year": ["2020", "2020", "2020"],
985+
"month": ["1", "2", "2"],
986+
"day": ["1", "3", "5"],
987+
}
988+
actual = (
989+
pa.Table.from_batches(dt.datafusion_read(predicate="year = '2020'"))
990+
.sort_by("value")
991+
.to_pydict()
992+
)
993+
assert expected == actual
994+
995+
996+
def test_datafusion_read_with_filter_on_multiple_columns():
997+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
998+
dt = DeltaTable(table_path)
999+
expected = {
1000+
"value": ["4", "5"],
1001+
"year": ["2021", "2021"],
1002+
"month": ["4", "12"],
1003+
"day": ["5", "4"],
1004+
}
1005+
actual = (
1006+
pa.Table.from_batches(
1007+
dt.datafusion_read(predicate="year = '2021' and value < '6'")
1008+
)
1009+
.sort_by("value")
1010+
.to_pydict()
1011+
)
1012+
assert expected == actual

0 commit comments

Comments
 (0)