Skip to content

Commit 3335a03

Browse files
PeterKeDer authored and rtyler committed
Implement query builder
Signed-off-by: Peter Ke <[email protected]>
1 parent 1083c8c commit 3335a03

File tree

9 files changed

+175
-4
lines changed

9 files changed

+175
-4
lines changed

crates/core/src/delta_datafusion/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,9 +826,12 @@ impl TableProvider for DeltaTableProvider {
826826

827827
fn supports_filters_pushdown(
828828
&self,
829-
_filter: &[&Expr],
829+
filter: &[&Expr],
830830
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
831-
Ok(vec![TableProviderFilterPushDown::Inexact])
831+
Ok(filter
832+
.iter()
833+
.map(|_| TableProviderFilterPushDown::Inexact)
834+
.collect())
832835
}
833836

834837
fn statistics(&self) -> Option<Statistics> {

python/deltalake/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._internal import __version__ as __version__
33
from ._internal import rust_core_version as rust_core_version
44
from .data_catalog import DataCatalog as DataCatalog
5+
from .query import QueryBuilder as QueryBuilder
56
from .schema import DataType as DataType
67
from .schema import Field as Field
78
from .schema import Schema as Schema

python/deltalake/_internal.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,11 @@ class DeltaFileSystemHandler:
874874
) -> ObjectOutputStream:
875875
"""Open an output stream for sequential writing."""
876876

877+
class PyQueryBuilder:
    """Rust-backed SQL query engine over explicitly registered Delta tables."""

    def __init__(self) -> None: ...
    def register(self, table_name: str, delta_table: RawDeltaTable) -> None:
        """Register ``delta_table`` under ``table_name`` for use in SQL queries."""
    def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
        """Run ``sql`` against the registered tables and return the result batches."""
881+
877882
class DeltaDataChecker:
878883
def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
879884
def check_batch(self, batch: pyarrow.RecordBatch) -> None: ...

python/deltalake/query.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
from typing import List
5+
6+
import pyarrow
7+
8+
from deltalake._internal import PyQueryBuilder
9+
from deltalake.table import DeltaTable
10+
from deltalake.warnings import ExperimentalWarning
11+
12+
13+
class QueryBuilder:
    """Experimental SQL query interface over one or more Delta tables.

    Tables are registered under a name via :meth:`register` and can then be
    referenced by that name in SQL passed to :meth:`execute`.
    """

    def __init__(self) -> None:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this constructor, so users see where THEY opted in.
        warnings.warn(
            "QueryBuilder is experimental and subject to change",
            category=ExperimentalWarning,
            stacklevel=2,
        )
        self._query_builder = PyQueryBuilder()

    def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
        """Add a table to the query builder.

        Args:
            table_name: Name the table will be referenced by in SQL queries.
            delta_table: The Delta table to register.

        Returns:
            This builder, so ``register`` calls can be chained.
        """
        self._query_builder.register(
            table_name=table_name,
            delta_table=delta_table._table,
        )
        return self

    def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
        """Execute ``sql`` against the registered tables and return a list of record batches."""
        return self._query_builder.execute(sql)

python/deltalake/warnings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class ExperimentalWarning(Warning):
    """Emitted when an experimental deltalake API is used."""

python/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use arrow_schema::ArrowError;
2+
use deltalake::datafusion::error::DataFusionError;
23
use deltalake::protocol::ProtocolError;
34
use deltalake::{errors::DeltaTableError, ObjectStoreError};
45
use pyo3::exceptions::{
@@ -79,6 +80,10 @@ fn checkpoint_to_py(err: ProtocolError) -> PyErr {
7980
}
8081
}
8182

83+
fn datafusion_to_py(err: DataFusionError) -> PyErr {
84+
DeltaError::new_err(err.to_string())
85+
}
86+
8287
#[derive(thiserror::Error, Debug)]
8388
pub enum PythonError {
8489
#[error("Error in delta table")]
@@ -89,6 +94,8 @@ pub enum PythonError {
8994
Arrow(#[from] ArrowError),
9095
#[error("Error in checkpoint")]
9196
Protocol(#[from] ProtocolError),
97+
#[error("Error in data fusion")]
98+
DataFusion(#[from] DataFusionError),
9299
}
93100

94101
impl From<PythonError> for pyo3::PyErr {
@@ -98,6 +105,7 @@ impl From<PythonError> for pyo3::PyErr {
98105
PythonError::ObjectStore(err) => object_store_to_py(err),
99106
PythonError::Arrow(err) => arrow_to_py(err),
100107
PythonError::Protocol(err) => checkpoint_to_py(err),
108+
PythonError::DataFusion(err) => datafusion_to_py(err),
101109
}
102110
}
103111
}

python/src/lib.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod error;
22
mod features;
33
mod filesystem;
44
mod merge;
5+
mod query;
56
mod schema;
67
mod utils;
78

@@ -20,12 +21,18 @@ use delta_kernel::expressions::Scalar;
2021
use delta_kernel::schema::StructField;
2122
use deltalake::arrow::compute::concat_batches;
2223
use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
24+
use deltalake::arrow::pyarrow::ToPyArrow;
2325
use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
2426
use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
2527
use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
28+
use deltalake::datafusion::datasource::provider_as_source;
29+
use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
2630
use deltalake::datafusion::physical_plan::ExecutionPlan;
27-
use deltalake::datafusion::prelude::SessionContext;
28-
use deltalake::delta_datafusion::DeltaDataChecker;
31+
use deltalake::datafusion::prelude::{DataFrame, SessionContext};
32+
use deltalake::delta_datafusion::{
33+
DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
34+
DeltaTableProvider,
35+
};
2936
use deltalake::errors::DeltaTableError;
3037
use deltalake::kernel::{
3138
scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction,
@@ -69,6 +76,7 @@ use crate::error::PythonError;
6976
use crate::features::TableFeatures;
7077
use crate::filesystem::FsConfig;
7178
use crate::merge::PyMergeBuilder;
79+
use crate::query::PyQueryBuilder;
7280
use crate::schema::{schema_to_pyobject, Field};
7381
use crate::utils::rt;
7482

@@ -2095,6 +2103,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
20952103
)?)?;
20962104
m.add_class::<RawDeltaTable>()?;
20972105
m.add_class::<PyMergeBuilder>()?;
2106+
m.add_class::<PyQueryBuilder>()?;
20982107
m.add_class::<RawDeltaTableMetaData>()?;
20992108
m.add_class::<PyDeltaDataChecker>()?;
21002109
m.add_class::<PyTransaction>()?;

python/src/query.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use std::sync::Arc;
2+
3+
use deltalake::{
4+
arrow::pyarrow::ToPyArrow,
5+
datafusion::prelude::SessionContext,
6+
delta_datafusion::{DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider},
7+
};
8+
use pyo3::prelude::*;
9+
10+
use crate::{error::PythonError, utils::rt, RawDeltaTable};
11+
12+
#[pyclass(module = "deltalake._internal")]
13+
pub(crate) struct PyQueryBuilder {
14+
_ctx: SessionContext,
15+
}
16+
17+
#[pymethods]
18+
impl PyQueryBuilder {
19+
#[new]
20+
pub fn new() -> Self {
21+
let config = DeltaSessionConfig::default().into();
22+
let _ctx = SessionContext::new_with_config(config);
23+
24+
PyQueryBuilder { _ctx }
25+
}
26+
27+
pub fn register(&self, table_name: &str, delta_table: &RawDeltaTable) -> PyResult<()> {
28+
let snapshot = delta_table._table.snapshot().map_err(PythonError::from)?;
29+
let log_store = delta_table._table.log_store();
30+
31+
let scan_config = DeltaScanConfigBuilder::default()
32+
.build(snapshot)
33+
.map_err(PythonError::from)?;
34+
35+
let provider = Arc::new(
36+
DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
37+
.map_err(PythonError::from)?,
38+
);
39+
40+
self._ctx
41+
.register_table(table_name, provider)
42+
.map_err(PythonError::from)?;
43+
44+
Ok(())
45+
}
46+
47+
pub fn execute(&self, py: Python, sql: &str) -> PyResult<PyObject> {
48+
let batches = py.allow_threads(|| {
49+
rt().block_on(async {
50+
let df = self._ctx.sql(sql).await?;
51+
df.collect().await
52+
})
53+
.map_err(PythonError::from)
54+
})?;
55+
56+
batches.to_pyarrow(py)
57+
}
58+
}

python/tests/test_table_read.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from deltalake._util import encode_partition_value
1111
from deltalake.exceptions import DeltaProtocolError
12+
from deltalake.query import QueryBuilder
1213
from deltalake.table import ProtocolVersions
1314
from deltalake.writer import write_deltalake
1415

@@ -946,3 +947,56 @@ def test_is_deltatable_with_storage_opts():
946947
"DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
947948
}
948949
assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)
950+
951+
952+
def test_read_query_builder():
    # Partitioned test table shipped with the Rust crate's fixtures.
    dt = DeltaTable("../crates/test/tests/data/delta-0.8.0-partitioned")

    sql = "SELECT * FROM tbl WHERE year >= 2021 ORDER BY value"
    batches = QueryBuilder().register("tbl", dt).execute(sql)
    actual = pa.Table.from_batches(batches).to_pydict()

    expected = {
        "value": ["4", "5", "6", "7"],
        "year": ["2021", "2021", "2021", "2021"],
        "month": ["4", "12", "12", "12"],
        "day": ["5", "4", "20", "20"],
    }
    assert expected == actual
967+
968+
969+
def test_read_query_builder_join_multiple_tables(tmp_path):
    # First table: fixture providing `date` and `dayOfYear` columns.
    dt1 = DeltaTable("../crates/test/tests/data/delta-0.8.0-date")

    # Second table: written on the fly, sharing the `date` join key.
    write_deltalake(
        tmp_path,
        pa.table(
            {
                "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-12-31"],
                "value": ["a", "b", "c", "d"],
            }
        ),
    )
    dt2 = DeltaTable(tmp_path)

    builder = QueryBuilder().register("tbl1", dt1).register("tbl2", dt2)
    batches = builder.execute(
        """
            SELECT tbl2.date, tbl1.dayOfYear, tbl2.value
            FROM tbl1
            INNER JOIN tbl2 ON tbl1.date = tbl2.date
            ORDER BY tbl1.date
            """
    )
    actual = pa.Table.from_batches(batches).to_pydict()

    # Only the three dates present in BOTH tables survive the inner join.
    expected = {
        "date": ["2021-01-01", "2021-01-02", "2021-01-03"],
        "dayOfYear": [1, 2, 3],
        "value": ["a", "b", "c"],
    }
    assert expected == actual

0 commit comments

Comments
 (0)