Skip to content
22 changes: 22 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,17 @@ pub enum Statement {
variables: OneOrManyWithParens<ObjectName>,
value: Vec<Expr>,
},

/// ```sql
/// SET <variable> = expression [, <variable> = expression]*;
/// ```
///
/// Note: this is a MySQL-specific statement.
/// Refer to [`Dialect.supports_comma_separated_set_assignments`]
SetVariables {
variables: Vec<ObjectName>,
values: Vec<Expr>,
},
/// ```sql
/// SET TIME ZONE <value>
/// ```
Expand Down Expand Up @@ -5334,6 +5345,17 @@ impl fmt::Display for Statement {
Statement::List(command) => write!(f, "LIST {command}"),
Statement::Remove(command) => write!(f, "REMOVE {command}"),
Statement::SetSessionParam(kind) => write!(f, "SET {kind}"),
Statement::SetVariables { variables, values } => write!(
f,
"SET {}",
display_comma_separated(
&variables
.iter()
.zip(values.iter())
.map(|(var, val)| format!("{var} = {val}"))
.collect::<Vec<_>>()
)
),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ impl Spanned for Statement {
Statement::RaisError { .. } => Span::empty(),
Statement::List(..) | Statement::Remove(..) => Span::empty(),
Statement::SetSessionParam { .. } => Span::empty(),
Statement::SetVariables { .. } => Span::empty(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports multiple `SET` statements
/// in a single statement.
///
/// ```sql
/// SET variable = expression [, variable = expression];
/// ```
fn supports_comma_separated_set_assignments(&self) -> bool {
false
}

/// Returns true if the dialect supports an `EXCEPT` clause following a
/// wildcard in a select list.
///
Expand Down
1 change: 1 addition & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl Dialect for MsSqlDialect {
fn supports_start_transaction_modifier(&self) -> bool {
true
}

fn supports_end_transaction_modifier(&self) -> bool {
true
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ impl Dialect for MySqlDialect {
fn supports_set_names(&self) -> bool {
true
}

fn supports_comma_separated_set_assignments(&self) -> bool {
true
}
}

/// `LOCK TABLES`
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ define_keywords!(
CHANNEL,
CHAR,
CHARACTER,
CHARACTERISTICS,
CHARACTERS,
CHARACTER_LENGTH,
CHARSET,
Expand Down Expand Up @@ -557,6 +558,7 @@ define_keywords!(
MULTISET,
MUTATION,
NAME,
NAMES,
NANOSECOND,
NANOSECONDS,
NATIONAL,
Expand Down
223 changes: 147 additions & 76 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10961,41 +10961,100 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);
if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
} else if let Some(set_role_stmt) =
self.maybe_parse(|parser| parser.parse_set_role(modifier))?
{
return Ok(set_role_stmt);
fn parse_set_values(
&mut self,
parenthesized_assignment: bool,
) -> Result<Vec<Expr>, ParserError> {
let mut values = vec![];

if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}

loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(values);
}
}

let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
fn parse_set_assignment(
&mut self,
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> {
let variables = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
let vars = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
vars
} else {
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

let names = matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES"));
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
return self.expected("assignment operator", self.peek_token());
}

let values = self.parse_expr()?;

Ok((variables, values))
}

fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);

if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
}

if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? {
return Ok(set_role_stmt);
}

if names && self.dialect.supports_set_names() {
// Handle special cases first
if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE])
|| self.parse_keyword(Keyword::TIMEZONE)
{
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables: OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])),
value: self.parse_set_values(false)?,
});
} else if self.dialect.is::<PostgreSqlDialect>() {
// Special case for Postgres
return Ok(Statement::SetTimeZone {
local: modifier == Some(Keyword::LOCAL),
value: self.parse_expr()?,
});
} else {
return self.expected("assignment operator", self.peek_token());
}
} else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) {
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
}

let charset_name = self.parse_identifier()?;
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
Some(self.parse_literal_string()?)
Expand All @@ -11007,63 +11066,14 @@ impl<'a> Parser<'a> {
charset_name,
collation_name,
});
}

let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}

let mut values = vec![];
loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: Some(Keyword::HIVEVAR) == modifier,
variables,
value: values,
});
}
}

let OneOrManyWithParens::One(variable) = variables else {
return self.expected("set variable", self.peek_token());
};

if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
// for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value>
match self.parse_expr() {
Ok(expr) => Ok(Statement::SetTimeZone {
local: modifier == Some(Keyword::LOCAL),
value: expr,
}),
_ => self.expected("timezone value", self.peek_token())?,
}
} else if variable.to_string() == "CHARACTERISTICS" {
} else if self.parse_keyword(Keyword::CHARACTERISTICS) {
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
Ok(Statement::SetTransaction {
return Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: true,
})
} else if variable.to_string() == "TRANSACTION" && modifier.is_none() {
});
} else if self.parse_keyword(Keyword::TRANSACTION) {
if self.parse_keyword(Keyword::SNAPSHOT) {
let snapshot_id = self.parse_value()?.value;
return Ok(Statement::SetTransaction {
Expand All @@ -11072,17 +11082,78 @@ impl<'a> Parser<'a> {
session: false,
});
}
Ok(Statement::SetTransaction {
return Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: false,
})
} else if self.dialect.supports_set_stmt_without_operator() {
self.prev_token();
self.parse_set_session_params()
});
}

if self.dialect.supports_comma_separated_set_assignments() {
if let Ok(v) =
self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))
{
let (vars, values): (Vec<_>, Vec<_>) = v.into_iter().unzip();

return if vars.len() > 1 {
let variables = vars
.into_iter()
.map(|v| match v {
OneOrManyWithParens::One(v) => Ok(v),
_ => self.expected("List of single identifiers", self.peek_token()),
})
.collect::<Result<_, _>>()?;

Ok(Statement::SetVariables { variables, values })
} else {
let variable = match vars.into_iter().next() {
Some(v) => Ok(v),
None => self.expected("At least one identifier", self.peek_token()),
}?;

Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables: variable,
value: values,
})
};
}
}

let variables = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let vars = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
vars
} else {
self.expected("equals sign or TO", self.peek_token())
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
let values = self.parse_set_values(parenthesized_assignment)?;

return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables,
value: values,
});
}

if self.dialect.supports_set_stmt_without_operator() {
self.prev_token();
return self.parse_set_session_params();
};

self.expected("equals sign or TO", self.peek_token())
}

pub fn parse_set_session_params(&mut self) -> Result<Statement, ParserError> {
Expand Down
Loading