Skip to content
6 changes: 6 additions & 0 deletions python/datafusion/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@
SqlTable = common_internal.SqlTable
SqlType = common_internal.SqlType
SqlView = common_internal.SqlView
TableType = common_internal.TableType
TableSource = common_internal.TableSource
Constraints = common_internal.Constraints

__all__ = [
"Constraints",
"DFSchema",
"DataType",
"DataTypeMap",
Expand All @@ -47,6 +51,8 @@
"SqlTable",
"SqlType",
"SqlView",
"TableSource",
"TableType",
]


Expand Down
50 changes: 50 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,29 @@
Case = expr_internal.Case
Cast = expr_internal.Cast
Column = expr_internal.Column
CopyTo = expr_internal.CopyTo
CreateCatalog = expr_internal.CreateCatalog
CreateCatalogSchema = expr_internal.CreateCatalogSchema
CreateExternalTable = expr_internal.CreateExternalTable
CreateFunction = expr_internal.CreateFunction
CreateFunctionBody = expr_internal.CreateFunctionBody
CreateIndex = expr_internal.CreateIndex
CreateMemoryTable = expr_internal.CreateMemoryTable
CreateView = expr_internal.CreateView
Deallocate = expr_internal.Deallocate
DescribeTable = expr_internal.DescribeTable
Distinct = expr_internal.Distinct
DmlStatement = expr_internal.DmlStatement
DropCatalogSchema = expr_internal.DropCatalogSchema
DropFunction = expr_internal.DropFunction
DropTable = expr_internal.DropTable
DropView = expr_internal.DropView
EmptyRelation = expr_internal.EmptyRelation
Execute = expr_internal.Execute
Exists = expr_internal.Exists
Explain = expr_internal.Explain
Extension = expr_internal.Extension
FileType = expr_internal.FileType
Filter = expr_internal.Filter
GroupingSet = expr_internal.GroupingSet
Join = expr_internal.Join
Expand All @@ -83,21 +98,31 @@
Literal = expr_internal.Literal
Negative = expr_internal.Negative
Not = expr_internal.Not
OperateFunctionArg = expr_internal.OperateFunctionArg
Partitioning = expr_internal.Partitioning
Placeholder = expr_internal.Placeholder
Prepare = expr_internal.Prepare
Projection = expr_internal.Projection
RecursiveQuery = expr_internal.RecursiveQuery
Repartition = expr_internal.Repartition
ScalarSubquery = expr_internal.ScalarSubquery
ScalarVariable = expr_internal.ScalarVariable
SetVariable = expr_internal.SetVariable
SimilarTo = expr_internal.SimilarTo
Sort = expr_internal.Sort
Subquery = expr_internal.Subquery
SubqueryAlias = expr_internal.SubqueryAlias
TableScan = expr_internal.TableScan
TransactionAccessMode = expr_internal.TransactionAccessMode
TransactionConclusion = expr_internal.TransactionConclusion
TransactionEnd = expr_internal.TransactionEnd
TransactionIsolationLevel = expr_internal.TransactionIsolationLevel
TransactionStart = expr_internal.TransactionStart
TryCast = expr_internal.TryCast
Union = expr_internal.Union
Unnest = expr_internal.Unnest
UnnestExpr = expr_internal.UnnestExpr
Values = expr_internal.Values
WindowExpr = expr_internal.WindowExpr

__all__ = [
Expand All @@ -111,15 +136,30 @@
"CaseBuilder",
"Cast",
"Column",
"CopyTo",
"CreateCatalog",
"CreateCatalogSchema",
"CreateExternalTable",
"CreateFunction",
"CreateFunctionBody",
"CreateIndex",
"CreateMemoryTable",
"CreateView",
"Deallocate",
"DescribeTable",
"Distinct",
"DmlStatement",
"DropCatalogSchema",
"DropFunction",
"DropTable",
"DropView",
"EmptyRelation",
"Execute",
"Exists",
"Explain",
"Expr",
"Extension",
"FileType",
"Filter",
"GroupingSet",
"ILike",
Expand All @@ -142,22 +182,32 @@
"Literal",
"Negative",
"Not",
"OperateFunctionArg",
"Partitioning",
"Placeholder",
"Prepare",
"Projection",
"RecursiveQuery",
"Repartition",
"ScalarSubquery",
"ScalarVariable",
"SetVariable",
"SimilarTo",
"Sort",
"SortExpr",
"Subquery",
"SubqueryAlias",
"TableScan",
"TransactionAccessMode",
"TransactionConclusion",
"TransactionEnd",
"TransactionIsolationLevel",
"TransactionStart",
"TryCast",
"Union",
"Unnest",
"UnnestExpr",
"Values",
"Window",
"WindowExpr",
"WindowFrame",
Expand Down
86 changes: 86 additions & 0 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@
AggregateFunction,
BinaryExpr,
Column,
CopyTo,
CreateIndex,
DescribeTable,
DmlStatement,
DropCatalogSchema,
Filter,
Limit,
Literal,
Projection,
RecursiveQuery,
Sort,
TableScan,
TransactionEnd,
TransactionStart,
Values,
)


Expand Down Expand Up @@ -247,3 +256,80 @@ def test_fill_null(df):
assert result.column(0) == pa.array([1, 2, 100])
assert result.column(1) == pa.array([4, 25, 6])
assert result.column(2) == pa.array([1234, 1234, 8])


def test_copy_to():
    """A COPY ... TO statement should surface as a ``CopyTo`` plan variant."""
    session = SessionContext()
    # The source table must exist before the COPY statement is planned.
    session.sql("CREATE TABLE foo (a int, b int)").collect()
    copy_df = session.sql("COPY foo TO bar STORED AS CSV")
    variant = copy_df.logical_plan().to_variant()
    assert isinstance(variant, CopyTo)


def test_create_index():
    """A CREATE INDEX statement should surface as a ``CreateIndex`` plan variant."""
    session = SessionContext()
    # Create the table that the index statement refers to.
    session.sql("CREATE TABLE foo (a int, b int)").collect()
    logical = session.sql("create index idx on foo (a)").logical_plan()
    variant = logical.to_variant()
    assert isinstance(variant, CreateIndex)


def test_describe_table():
    """A DESCRIBE statement should surface as a ``DescribeTable`` plan variant."""
    session = SessionContext()
    # Create the table that is described below.
    session.sql("CREATE TABLE foo (a int, b int)").collect()
    variant = session.sql("describe foo").logical_plan().to_variant()
    assert isinstance(variant, DescribeTable)


def test_dml_statement():
    """An INSERT statement should surface as a ``DmlStatement`` plan variant."""
    session = SessionContext()
    # Create the target table for the insert.
    session.sql("CREATE TABLE foo (a int, b int)").collect()
    insert_df = session.sql("insert into foo values (1, 2)")
    logical = insert_df.logical_plan()
    variant = logical.to_variant()
    assert isinstance(variant, DmlStatement)


def test_drop_catalog_schema():
    """A DROP SCHEMA statement should surface as a ``DropCatalogSchema`` plan variant.

    The function carries the ``test_`` prefix so that pytest's default
    discovery collects it; without the prefix the assertion never runs.
    """
    ctx = SessionContext()
    plan = ctx.sql("drop schema cat").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, DropCatalogSchema)


def test_recursive_query():
    """A recursive CTE should contain a ``RecursiveQuery`` node two inputs below the root."""
    session = SessionContext()
    query = """
WITH RECURSIVE cte AS (
SELECT 1 as n
UNION ALL
SELECT n + 1 FROM cte WHERE n < 5
)
SELECT * FROM cte;
"""
    root = session.sql(query).logical_plan()
    # Descend through the first input of the root, then the first input of that node.
    child = root.inputs()[0]
    grandchild = child.inputs()[0]
    variant = grandchild.to_variant()
    assert isinstance(variant, RecursiveQuery)


def test_values():
    """A bare VALUES list should surface as a ``Values`` plan variant."""
    session = SessionContext()
    values_df = session.sql("values (1, 'foo'), (2, 'bar')")
    variant = values_df.logical_plan().to_variant()
    assert isinstance(variant, Values)


def test_transaction_start():
    """START TRANSACTION should surface as a ``TransactionStart`` plan variant."""
    session = SessionContext()
    logical = session.sql("START TRANSACTION").logical_plan()
    variant = logical.to_variant()
    assert isinstance(variant, TransactionStart)


def test_transaction_end():
    """COMMIT should surface as a ``TransactionEnd`` plan variant."""
    session = SessionContext()
    variant = session.sql("COMMIT").logical_plan().to_variant()
    assert isinstance(variant, TransactionEnd)
3 changes: 3 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<schema::SqlView>()?;
m.add_class::<schema::SqlStatistics>()?;
m.add_class::<function::SqlFunction>()?;
m.add_class::<schema::PyTableType>()?;
m.add_class::<schema::PyTableSource>()?;
m.add_class::<schema::PyConstraints>()?;
Ok(())
}
89 changes: 89 additions & 0 deletions src/common/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use std::{any::Any, borrow::Cow};

use arrow::datatypes::Schema;
use arrow::pyarrow::PyArrowType;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::Constraints;
use datafusion::datasource::TableType;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
use pyo3::prelude::*;

use datafusion::logical_expr::utils::split_conjunction;

use crate::sql::logical::PyLogicalPlan;

use super::{data_type::DataTypeMap, function::SqlFunction};

#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
Expand Down Expand Up @@ -218,3 +226,84 @@ impl SqlStatistics {
self.row_count
}
}

/// Python wrapper around DataFusion's [`Constraints`].
///
/// The class is declared under `datafusion.common`, the same module as the
/// `TableType` and `TableSource` classes in this file, so that the module
/// reported to Python matches where the class is registered and re-exported.
#[pyclass(name = "Constraints", module = "datafusion.common", subclass)]
#[derive(Clone)]
pub struct PyConstraints {
    /// The wrapped DataFusion constraints value.
    pub constraints: Constraints,
}

impl From<PyConstraints> for Constraints {
fn from(constraints: PyConstraints) -> Self {
constraints.constraints
}
}

impl From<Constraints> for PyConstraints {
fn from(constraints: Constraints) -> Self {
PyConstraints { constraints }
}
}

/// Renders the wrapper as `Constraints: ` followed by the `Debug` form of the
/// wrapped value.
impl Display for PyConstraints {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_fmt(format_args!("Constraints: {:?}", self.constraints))
    }
}

/// Python-facing enumeration of table kinds, exposed to Python as
/// `datafusion.common.TableType`.
///
/// Each variant converts one-to-one to and from the DataFusion `TableType`
/// variant of the same name via the `From` impls in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")]
pub enum PyTableType {
/// Counterpart of DataFusion's `TableType::Base`.
Base,
/// Counterpart of DataFusion's `TableType::View`.
View,
/// Counterpart of DataFusion's `TableType::Temporary`.
Temporary,
}

/// Maps each Python-facing table type onto the DataFusion variant of the same name.
impl From<PyTableType> for datafusion::logical_expr::TableType {
    fn from(py_kind: PyTableType) -> Self {
        match py_kind {
            PyTableType::Base => Self::Base,
            PyTableType::View => Self::View,
            PyTableType::Temporary => Self::Temporary,
        }
    }
}

impl From<TableType> for PyTableType {
fn from(table_type: TableType) -> Self {
match table_type {
datafusion::logical_expr::TableType::Base => PyTableType::Base,
datafusion::logical_expr::TableType::View => PyTableType::View,
datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
}
}
}

/// Python wrapper around a shared DataFusion `TableSource` trait object,
/// exposed to Python as `datafusion.common.TableSource`.
#[pyclass(name = "TableSource", module = "datafusion.common", subclass)]
#[derive(Clone)]
pub struct PyTableSource {
// The wrapped table source; every accessor on this type delegates to it.
pub table_source: Arc<dyn TableSource>,
}

#[pymethods]
impl PyTableSource {
pub fn schema(&self) -> PyArrowType<Schema> {
(*self.table_source.schema()).clone().into()
}

pub fn constraints(&self) -> Option<PyConstraints> {
self.table_source.constraints().map(|c| PyConstraints {
constraints: c.clone(),
})
}

pub fn table_type(&self) -> PyTableType {
self.table_source.table_type().into()
}

pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
self.table_source
.get_logical_plan()
.map(|plan| PyLogicalPlan::new(plan.into_owned()))
}
}
Loading