Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ pub enum Expr {
/// not `< 0` nor `1, 2, 3` as allowed in a `<simple when clause>` per
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case {
case_token: AttachedToken,
end_token: AttachedToken,
operand: Option<Box<Expr>>,
conditions: Vec<CaseWhen>,
else_result: Option<Box<Expr>>,
Expand Down Expand Up @@ -1675,6 +1677,8 @@ impl fmt::Display for Expr {
}
Expr::Function(fun) => fun.fmt(f),
Expr::Case {
case_token: _,
end_token: _,
operand,
conditions,
else_result,
Expand Down
22 changes: 14 additions & 8 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,18 +1566,24 @@ impl Spanned for Expr {
),
Expr::Prefixed { value, .. } => value.span(),
Expr::Case {
case_token,
end_token,
operand,
conditions,
else_result,
} => union_spans(
operand
.as_ref()
.map(|i| i.span())
.into_iter()
.chain(conditions.iter().flat_map(|case_when| {
[case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())),
iter::once(case_token.0.span)
.chain(
operand
.as_ref()
.map(|i| i.span())
.into_iter()
.chain(conditions.iter().flat_map(|case_when| {
[case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())),
)
.chain(iter::once(end_token.0.span)),
),
Expr::Exists { subquery, .. } => subquery.span(),
Expr::Subquery(query) => query.span(),
Expand Down
5 changes: 4 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,7 @@ impl<'a> Parser<'a> {
}

pub fn parse_case_expr(&mut self) -> Result<Expr, ParserError> {
let case_token = AttachedToken(self.get_current_token().clone());
let mut operand = None;
if !self.parse_keyword(Keyword::WHEN) {
operand = Some(Box::new(self.parse_expr()?));
Expand All @@ -2294,8 +2295,10 @@ impl<'a> Parser<'a> {
} else {
None
};
self.expect_keyword_is(Keyword::END)?;
let end_token = AttachedToken(self.expect_keyword(Keyword::END)?);
Ok(Expr::Case {
case_token,
end_token,
operand,
conditions,
else_result,
Expand Down
16 changes: 16 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6861,6 +6861,8 @@ fn parse_searched_case_expr() {
let select = verified_only_select(sql);
assert_eq!(
&Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
operand: None,
conditions: vec![
CaseWhen {
Expand Down Expand Up @@ -6900,6 +6902,8 @@ fn parse_simple_case_expr() {
use self::Expr::{Case, Identifier};
assert_eq!(
&Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
operand: Some(Box::new(Identifier(Ident::new("foo")))),
conditions: vec![CaseWhen {
condition: Expr::value(number("1")),
Expand Down Expand Up @@ -14464,6 +14468,16 @@ fn test_case_statement_span() {
);
}

#[test]
fn test_case_expr_span() {
let sql = "CASE 1 WHEN 2 THEN 3 ELSE 4 END";
let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap();
assert_eq!(
parser.parse_expr().unwrap().span(),
Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1))
);
}

#[test]
fn parse_if_statement() {
let dialects = all_dialects_except(|d| d.is::<MsSqlDialect>());
Expand Down Expand Up @@ -14642,6 +14656,8 @@ fn test_lambdas() {
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
operand: None,
conditions: vec![
CaseWhen {
Expand Down
3 changes: 3 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use sqlparser::ast::helpers::attached_token::AttachedToken;
use sqlparser::ast::*;
use sqlparser::dialect::{DatabricksDialect, GenericDialect};
use sqlparser::parser::ParserError;
Expand Down Expand Up @@ -108,6 +109,8 @@ fn test_databricks_lambdas() {
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
operand: None,
conditions: vec![
CaseWhen {
Expand Down