Skip to content
10 changes: 10 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports multiple `SET` statements
/// in a single statement.
///
/// ```sql
/// SET variable = expression [, variable = expression];
/// ```
fn supports_comma_separated_set_assignments(&self) -> bool {
false
}

/// Returns true if the dialect supports an `EXCEPT` clause following a
/// wildcard in a select list.
///
Expand Down
1 change: 1 addition & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl Dialect for MsSqlDialect {
fn supports_start_transaction_modifier(&self) -> bool {
true
}

fn supports_end_transaction_modifier(&self) -> bool {
true
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ impl Dialect for MySqlDialect {
fn supports_set_names(&self) -> bool {
true
}

fn supports_comma_separated_set_assignments(&self) -> bool {
true
}
}

/// `LOCK TABLES`
Expand Down
213 changes: 134 additions & 79 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10961,127 +10961,182 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);
if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
} else if let Some(set_role_stmt) =
self.maybe_parse(|parser| parser.parse_set_role(modifier))?
{
return Ok(set_role_stmt);
fn parse_set_values(
&mut self,
parenthesized_assignment: bool,
) -> Result<Vec<Expr>, ParserError> {
let mut values = vec![];

if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}

let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(values);
}
}

fn parse_set_assignment(
&mut self,
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> {
let variables = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
let vars = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
vars
} else {
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

let names = matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES"));
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
return self.expected("assignment operator", self.peek_token());
}

if names && self.dialect.supports_set_names() {
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
}
let values = self.parse_expr()?;

let charset_name = self.parse_identifier()?;
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
Some(self.parse_literal_string()?)
} else {
None
};
Ok((variables, values))
}

return Ok(Statement::SetNames {
charset_name,
collation_name,
});
fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);

if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
}

let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? {
return Ok(set_role_stmt);
}

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}
if self.dialect.supports_comma_separated_set_assignments() {
if let Ok(v) =
self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))
{
let (variables, values): (Vec<_>, Vec<_>) = v.into_iter().unzip();

let mut values = vec![];
loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
let variables = if variables.len() == 1 {
variables.into_iter().next().unwrap()
} else {
self.expected("variable value", self.peek_token())?
OneOrManyWithParens::Many(variables.into_iter().flatten().collect())
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: Some(Keyword::HIVEVAR) == modifier,
hivevar: modifier == Some(Keyword::HIVEVAR),
variables,
value: values,
});
}
}

let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
} else {
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
let values = self.parse_set_values(parenthesized_assignment)?;

return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables,
value: values,
});
}

let OneOrManyWithParens::One(variable) = variables else {
return self.expected("set variable", self.peek_token());
};

if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
// for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value>
match self.parse_expr() {
match variable.to_string().to_ascii_uppercase().as_str() {
"NAMES" if self.dialect.supports_set_names() => {
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
}
let charset_name = self.parse_identifier()?;
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
Some(self.parse_literal_string()?)
} else {
None
};

Ok(Statement::SetNames {
charset_name,
collation_name,
})
}
"TIMEZONE" => match self.parse_expr() {
Ok(expr) => Ok(Statement::SetTimeZone {
local: modifier == Some(Keyword::LOCAL),
value: expr,
}),
_ => self.expected("timezone value", self.peek_token())?,
}
} else if variable.to_string() == "CHARACTERISTICS" {
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: true,
})
} else if variable.to_string() == "TRANSACTION" && modifier.is_none() {
if self.parse_keyword(Keyword::SNAPSHOT) {
let snapshot_id = self.parse_value()?.value;
return Ok(Statement::SetTransaction {
modes: vec![],
snapshot: Some(snapshot_id),
_ => self.expected("timezone value", self.peek_token()),
},
"CHARACTERISTICS" => {
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: true,
})
}
"TRANSACTION" if modifier.is_none() => {
if self.parse_keyword(Keyword::SNAPSHOT) {
let snapshot_id = self.parse_value()?.value;
return Ok(Statement::SetTransaction {
modes: vec![],
snapshot: Some(snapshot_id),
session: false,
});
}
Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: false,
});
})
}
Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: false,
})
} else if self.dialect.supports_set_stmt_without_operator() {
self.prev_token();
self.parse_set_session_params()
} else {
self.expected("equals sign or TO", self.peek_token())
_ if self.dialect.supports_set_stmt_without_operator() => {
self.prev_token();
self.parse_set_session_params()
}
_ => self.expected("equals sign or TO", self.peek_token()),
}
}

Expand Down
34 changes: 20 additions & 14 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8728,20 +8728,6 @@ fn parse_set_time_zone() {
one_statement_parses_to("SET TIME ZONE TO 'UTC'", "SET TIMEZONE = 'UTC'");
}

#[test]
fn parse_set_time_zone_alias() {
match verified_stmt("SET TIME ZONE 'UTC'") {
Statement::SetTimeZone { local, value } => {
assert!(!local);
assert_eq!(
value,
Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span())
);
}
_ => unreachable!(),
}
}

#[test]
fn parse_commit() {
match verified_stmt("COMMIT") {
Expand Down Expand Up @@ -14654,3 +14640,23 @@ fn parse_set_names() {
dialects.verified_stmt("SET NAMES 'utf8'");
dialects.verified_stmt("SET NAMES UTF8 COLLATE bogus");
}

#[test]
fn parse_multiple_set_statements() -> Result<(), ParserError> {
let dialects = all_dialects_where(|d| d.supports_comma_separated_set_assignments());
let stmt = dialects.parse_sql_statements("SET @a = 1, b = 2")?;

let stmt = stmt[0].clone();

match stmt {
Statement::SetVariable {
variables, value, ..
} => {
assert_eq!(value.len(), 2);
assert_eq!(variables.len(), 2);
}
_ => panic!("Expected SetVariable with 2 variables and 2 values"),
};

Ok(())
}
2 changes: 1 addition & 1 deletion tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn parse_msck() {
}

#[test]
fn parse_set() {
fn parse_set_hivevar() {
let set = "SET HIVEVAR:name = a, b, c_d";
hive().verified_stmt(set);
}
Expand Down
14 changes: 14 additions & 0 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5638,6 +5638,20 @@ fn parse_create_type_as_enum() {
}
}

#[test]
fn parse_set_time_zone_alias() {
match pg().verified_stmt("SET TIME ZONE 'UTC'") {
Statement::SetTimeZone { local, value } => {
assert!(!local);
assert_eq!(
value,
Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span())
);
}
_ => unreachable!(),
}
}

#[test]
fn parse_alter_type() {
struct TestCase {
Expand Down