Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ impl SessionState {
let mut statements = DFParserBuilder::new(sql)
.with_dialect(dialect.as_ref())
.with_recursion_limit(recursion_limit)
.build()?
.parse_statements()?;

if statements.len() > 1 {
Expand Down Expand Up @@ -534,6 +535,7 @@ impl SessionState {
let expr = DFParserBuilder::new(sql)
.with_dialect(dialect.as_ref())
.with_recursion_limit(recursion_limit)
.build()?
.parse_expr()?;

Ok(expr)
Expand Down
90 changes: 71 additions & 19 deletions datafusion/sql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,46 @@ pub struct DFParser<'a> {
const DEFAULT_RECURSION_LIMIT: usize = 50;
const DEFAULT_DIALECT: GenericDialect = GenericDialect {};

/// Builder for [`DFParser`]
///
/// # Example: Create and Parse SQL statements
/// ```
/// # use datafusion_sql::parser::DFParserBuilder;
/// # use datafusion_common::Result;
/// # fn test() -> Result<()> {
/// let mut parser = DFParserBuilder::new("SELECT * FROM foo; SELECT 1 + 2")
/// .build()?;
/// // parse the SQL into DFStatements
/// let statements = parser.parse_statements()?;
/// assert_eq!(statements.len(), 2);
/// # Ok(())
/// # }
/// ```
///
/// # Example: Create and Parse expression with a different dialect
/// ```
/// # use datafusion_sql::parser::DFParserBuilder;
/// # use datafusion_common::Result;
/// # use datafusion_sql::sqlparser::dialect::MySqlDialect;
/// # use datafusion_sql::sqlparser::ast::Expr;
/// # fn test() -> Result<()> {
/// let dialect = MySqlDialect{}; // Parse using MySQL dialect
/// let mut parser = DFParserBuilder::new("1 + 2")
/// .with_dialect(&dialect)
/// .build()?;
/// // parse 1+2 into an sqlparser::ast::Expr
/// let res = parser.parse_expr()?;
/// assert!(matches!(res.expr, Expr::BinaryOp {..}));
/// # Ok(())
/// # }
/// ```
pub struct DFParserBuilder<'a> {
pub sql: &'a str,
pub dialect: &'a dyn Dialect,
pub recursion_limit: usize,
/// The SQL string to parse
sql: &'a str,
/// The Dialect to use (defaults to [`GenericDialect`]
dialect: &'a dyn Dialect,
/// The recursion limit while parsing
recursion_limit: usize,
}

impl<'a> DFParserBuilder<'a> {
Expand All @@ -290,29 +326,18 @@ impl<'a> DFParserBuilder<'a> {
}
}

/// Adjust the parser builder's dialect
/// Adjust the parser builder's dialect. Defaults to [`GenericDialect`]
pub fn with_dialect(mut self, dialect: &'a dyn Dialect) -> Self {
self.dialect = dialect;
self
}

/// Adjust the recursion limit of sql parsing
/// Adjust the recursion limit of sql parsing. Defaults to 50
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
self.recursion_limit = recursion_limit;

self
}

pub fn parse_statements(self) -> Result<VecDeque<Statement>, ParserError> {
let mut parser = self.build()?;
parser.parse_statements()
}

pub fn parse_expr(self) -> Result<ExprWithAlias, ParserError> {
let mut parser = self.build()?;
parser.parse_expr()
}

pub fn build(self) -> Result<DFParser<'a>, ParserError> {
let mut tokenizer = Tokenizer::new(self.dialect, self.sql);
let tokens = tokenizer.tokenize_with_location()?;
Expand All @@ -326,6 +351,19 @@ impl<'a> DFParserBuilder<'a> {
}

impl<'a> DFParser<'a> {
#[deprecated(since = "46.0.0", note = "DFParserBuilder")]
pub fn new(sql: &'a str) -> Result<Self, ParserError> {
DFParserBuilder::new(sql).build()
}

#[deprecated(since = "46.0.0", note = "DFParserBuilder")]
pub fn new_with_dialect(
sql: &'a str,
dialect: &'a dyn Dialect,
) -> Result<Self, ParserError> {
DFParserBuilder::new(sql).with_dialect(dialect).build()
}

/// Parse a sql string into one or [`Statement`]s using the
/// [`GenericDialect`].
pub fn parse_sql(sql: &'a str) -> Result<VecDeque<Statement>, ParserError> {
Expand Down Expand Up @@ -925,6 +963,7 @@ impl<'a> DFParser<'a> {
#[cfg(test)]
mod tests {
use super::*;
use datafusion_common::assert_contains;
use sqlparser::ast::Expr::Identifier;
use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident};
use sqlparser::dialect::SnowflakeDialect;
Expand Down Expand Up @@ -1662,10 +1701,23 @@ mod tests {
fn test_recursion_limit() {
let sql = "SELECT 1 OR 2";

assert!(DFParserBuilder::new(sql).parse_statements().is_ok());
assert!(DFParserBuilder::new(sql)
// Expect parse to succeed
DFParserBuilder::new(sql)
.build()
.unwrap()
.parse_statements()
.unwrap();

let err = DFParserBuilder::new(sql)
.with_recursion_limit(1)
.build()
.unwrap()
.parse_statements()
.is_err());
.unwrap_err();

assert_contains!(
err.to_string(),
"sql parser error: recursion limit exceeded"
);
}
}