Skip to content

Commit c1e6bca

Browse files
committed
support unparser
1 parent d0315ff commit c1e6bca

File tree

5 files changed

+202
-0
lines changed

5 files changed

+202
-0
lines changed

python/datafusion/unparser.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""This module provides support for unparsing datafusion plans to SQL.
19+
20+
For additional information about unparsing, see https://docs.rs/datafusion-sql/latest/datafusion_sql/unparser/index.html
21+
"""
22+
23+
from ._internal import unparser as unparser_internal
24+
from .plan import LogicalPlan
25+
26+
27+
class Dialect:
    """SQL dialect used by the unparser when rendering SQL text.

    Instances wrap the native ``Dialect`` object from the internal extension
    module. Use one of the static factory methods to obtain one.
    """

    def __init__(self, dialect: unparser_internal.Dialect) -> None:
        """This constructor is not typically called by the end user."""
        # Keep the native handle; ``Unparser`` reads it via ``.dialect``.
        self.dialect = dialect

    @staticmethod
    def default() -> "Dialect":
        """Create a new default dialect."""
        return Dialect(unparser_internal.Dialect.default())

    @staticmethod
    def mysql() -> "Dialect":
        """Create a new MySQL dialect."""
        return Dialect(unparser_internal.Dialect.mysql())

    @staticmethod
    def postgres() -> "Dialect":
        """Create a new PostgreSQL dialect."""
        return Dialect(unparser_internal.Dialect.postgres())

    @staticmethod
    def sqlite() -> "Dialect":
        """Create a new SQLite dialect."""
        return Dialect(unparser_internal.Dialect.sqlite())

    @staticmethod
    def duckdb() -> "Dialect":
        """Create a new DuckDB dialect."""
        return Dialect(unparser_internal.Dialect.duckdb())
53+
54+
55+
class Unparser:
    """Converts DataFusion logical plans back into SQL strings."""

    def __init__(self, dialect: Dialect) -> None:
        """Create an unparser that renders SQL using ``dialect``.

        Args:
            dialect: The :class:`Dialect` wrapper whose native handle is
                passed to the internal unparser.
        """
        self.unparser = unparser_internal.Unparser(dialect.dialect)

    def plan_to_sql(self, plan: LogicalPlan) -> str:
        """Convert a logical plan to a SQL string.

        Args:
            plan: The logical plan to render.

        Returns:
            The SQL text produced by the internal unparser.
        """
        return self.unparser.plan_to_sql(plan._raw_plan)

    def with_pretty(self, pretty: bool) -> "Unparser":
        """Return a copy of this unparser with the pretty flag set.

        The native ``with_pretty`` returns a new unparser rather than mutating
        the existing one, so this wrapper does the same and leaves ``self``
        untouched.

        Args:
            pretty: Value forwarded to the internal unparser's ``with_pretty``.

        Returns:
            A new :class:`Unparser` wrapping the reconfigured native unparser.
        """
        # Bypass ``__init__``: it needs a ``Dialect``, which is not retained here.
        configured = Unparser.__new__(Unparser)
        configured.unparser = self.unparser.with_pretty(pretty)
        return configured
65+
66+
__all__ = [
67+
"Dialect",
68+
"Unparser",
69+
]

python/tests/test_unparser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from datafusion.context import SessionContext
from datafusion.unparser import Dialect, Unparser
3+
4+
5+
def test_unparser():
    """Check that ``SELECT 1`` unparses to the same text for each dialect."""
    session = SessionContext()
    frame = session.sql("SELECT 1")
    dialect_factories = (
        Dialect.mysql,
        Dialect.postgres,
        Dialect.sqlite,
        Dialect.duckdb,
    )
    for make_dialect in dialect_factories:
        rendered = Unparser(make_dialect()).plan_to_sql(frame.logical_plan())
        assert rendered == "SELECT 1"

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub mod pyarrow_util;
5252
mod record_batch;
5353
pub mod sql;
5454
pub mod store;
55+
pub mod unparser;
5556

5657
#[cfg(feature = "substrait")]
5758
pub mod substrait;
@@ -103,6 +104,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
103104
expr::init_module(&expr)?;
104105
m.add_submodule(&expr)?;
105106

107+
let unparser = PyModule::new(py, "unparser")?;
108+
unparser::init_module(&unparser)?;
109+
m.add_submodule(&unparser)?;
110+
106111
// Register the functions as a submodule
107112
let funcs = PyModule::new(py, "functions")?;
108113
functions::init_module(&funcs)?;

src/unparser/dialect.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use datafusion::sql::unparser::dialect::{
21+
DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
22+
};
23+
use pyo3::prelude::*;
24+
25+
#[pyclass(name = "Dialect", module = "datafusion.unparser", subclass)]
26+
#[derive(Clone)]
27+
pub struct PyDialect {
28+
pub dialect: Arc<dyn Dialect>,
29+
}
30+
31+
#[pymethods]
32+
impl PyDialect {
33+
#[staticmethod]
34+
pub fn default() -> Self {
35+
Self {
36+
dialect: Arc::new(DefaultDialect {}),
37+
}
38+
}
39+
#[staticmethod]
40+
pub fn postgres() -> Self {
41+
Self {
42+
dialect: Arc::new(PostgreSqlDialect {}),
43+
}
44+
}
45+
#[staticmethod]
46+
pub fn mysql() -> Self {
47+
Self {
48+
dialect: Arc::new(MySqlDialect {}),
49+
}
50+
}
51+
#[staticmethod]
52+
pub fn sqlite() -> Self {
53+
Self {
54+
dialect: Arc::new(SqliteDialect {}),
55+
}
56+
}
57+
#[staticmethod]
58+
pub fn duckdb() -> Self {
59+
Self {
60+
dialect: Arc::new(DuckDBDialect::new()),
61+
}
62+
}
63+
}

src/unparser/mod.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
mod dialect;
2+
3+
use std::sync::Arc;
4+
5+
use datafusion::sql::unparser::{dialect::Dialect, Unparser};
6+
use dialect::PyDialect;
7+
use pyo3::{exceptions::PyValueError, prelude::*};
8+
9+
use crate::sql::logical::PyLogicalPlan;
10+
11+
#[pyclass(name = "Unparser", module = "datafusion.unparser", subclass)]
12+
#[derive(Clone)]
13+
pub struct PyUnparser {
14+
dialect: Arc<dyn Dialect>,
15+
pretty: bool,
16+
}
17+
18+
#[pymethods]
19+
impl PyUnparser {
20+
#[new]
21+
pub fn new(dialect: PyDialect) -> Self {
22+
Self {
23+
dialect: dialect.dialect.clone(),
24+
pretty: false,
25+
}
26+
}
27+
28+
pub fn plan_to_sql(&self, plan: &PyLogicalPlan) -> PyResult<String> {
29+
let mut unparser = Unparser::new(self.dialect.as_ref());
30+
unparser = unparser.with_pretty(self.pretty);
31+
let sql = unparser
32+
.plan_to_sql(&plan.plan())
33+
.map_err(|e| PyValueError::new_err(e.to_string()))?;
34+
Ok(sql.to_string())
35+
}
36+
37+
pub fn with_pretty(&self, pretty: bool) -> Self {
38+
Self {
39+
dialect: self.dialect.clone(),
40+
pretty,
41+
}
42+
}
43+
}
44+
45+
/// Registers the `Unparser` and `Dialect` classes on the given Python module.
///
/// # Errors
///
/// Propagates any error raised while adding either class to `m`.
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyUnparser>()?;
    // The final registration's result is the function's result.
    m.add_class::<PyDialect>()
}

0 commit comments

Comments
 (0)