diff --git a/src/ast/mod.rs b/src/ast/mod.rs index d008fd631..96a966817 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1106,7 +1106,11 @@ pub enum Statement { /// `COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]` Commit { chain: bool }, /// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]` - Rollback { chain: bool }, + /// `ROLLBACK [ TRANSACTION | WORK ] TO [ SAVEPOINT ] name` + Rollback { + savepoint: Option, + chain: bool, + }, /// CREATE SCHEMA CreateSchema { schema_name: ObjectName, @@ -1184,6 +1188,8 @@ pub enum Statement { }, /// SAVEPOINT -- define a new savepoint within the current transaction Savepoint { name: Ident }, + /// RELEASE -- release a previously defined savepoint + Release { name: Ident }, // MERGE INTO statement, based on Snowflake. See Merge { // Specifies the table to merge @@ -1956,8 +1962,17 @@ impl fmt::Display for Statement { Statement::Commit { chain } => { write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },) } - Statement::Rollback { chain } => { - write!(f, "ROLLBACK{}", if *chain { " AND CHAIN" } else { "" },) + Statement::Rollback { savepoint, chain } => { + write!( + f, + "ROLLBACK{}{}", + if let Some(savepoint) = savepoint { + format!(" TO {}", savepoint) + } else { + "".to_string() + }, + if *chain { " AND CHAIN" } else { "" }, + ) } Statement::CreateSchema { schema_name, @@ -2049,6 +2064,10 @@ impl fmt::Display for Statement { write!(f, "SAVEPOINT ")?; write!(f, "{}", name) } + Statement::Release { name } => { + write!(f, "RELEASE ")?; + write!(f, "{}", name) + } Statement::Merge { table, source, diff --git a/src/parser.rs b/src/parser.rs index 75903ce3d..c28ee6f38 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -186,6 +186,7 @@ impl<'a> Parser<'a> { // by at least PostgreSQL and MySQL. Keyword::BEGIN => Ok(self.parse_begin()?), Keyword::SAVEPOINT => Ok(self.parse_savepoint()?), + Keyword::RELEASE => Ok(self.parse_release()?), Keyword::COMMIT => Ok(self.parse_commit()?), Keyword::ROLLBACK => Ok(self.parse_rollback()?), Keyword::ASSERT => Ok(self.parse_assert()?), @@ -390,6 +391,12 @@ impl<'a> Parser<'a> { Ok(Statement::Savepoint { name }) } + pub fn parse_release(&mut self) -> Result { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + let name = self.parse_identifier()?; + Ok(Statement::Release { name }) + } + /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { // PostgreSQL allows any string literal to be preceded by a type name, indicating that the @@ -4654,19 +4661,27 @@ impl<'a> Parser<'a> { } pub fn parse_commit(&mut self) -> Result { + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); Ok(Statement::Commit { chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_rollback(&mut self) -> Result { + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); + let savepoint = if self.parse_keyword(Keyword::TO) { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + Some(self.parse_identifier()?) + } else { + None + }; Ok(Statement::Rollback { + savepoint, chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_commit_rollback_chain(&mut self) -> Result { - let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); if self.parse_keyword(Keyword::AND) { let chain = !self.parse_keyword(Keyword::NO); self.expect_keyword(Keyword::CHAIN)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index d4274b25a..ec782f10a 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -4474,12 +4474,28 @@ fn parse_commit() { #[test] fn parse_rollback() { match verified_stmt("ROLLBACK") { - Statement::Rollback { chain: false } => (), + Statement::Rollback { + savepoint: None, + chain: false, + } => (), _ => unreachable!(), } match verified_stmt("ROLLBACK AND CHAIN") { - Statement::Rollback { chain: true } => (), + Statement::Rollback { + savepoint: None, + chain: true, + } => (), + _ => unreachable!(), + } + + match verified_stmt("ROLLBACK TO foo") { + Statement::Rollback { + savepoint: Some(ident), + chain: false, + } => { + assert_eq!(ident.value, "foo") + } _ => unreachable!(), } @@ -4490,6 +4506,11 @@ fn parse_rollback() { one_statement_parses_to("ROLLBACK TRANSACTION AND CHAIN", "ROLLBACK AND CHAIN"); one_statement_parses_to("ROLLBACK WORK", "ROLLBACK"); one_statement_parses_to("ROLLBACK TRANSACTION", "ROLLBACK"); + one_statement_parses_to("ROLLBACK WORK TO foo", "ROLLBACK TO foo"); + one_statement_parses_to("ROLLBACK TRANSACTION TO foo", "ROLLBACK TO foo"); + one_statement_parses_to("ROLLBACK TO SAVEPOINT foo", "ROLLBACK TO foo"); + one_statement_parses_to("ROLLBACK WORK TO SAVEPOINT foo", "ROLLBACK TO foo"); + one_statement_parses_to("ROLLBACK TRANSACTION TO SAVEPOINT foo", "ROLLBACK TO foo"); } #[test] diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 13dd12a4a..892f61ff3 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1362,6 +1362,18 @@ fn test_savepoint() { } } +#[test] +fn test_release() { + match pg().verified_stmt("RELEASE test1") { + Statement::Release { name } => { + assert_eq!(Ident::new("test1"), name); + } + _ => unreachable!(), + } + + pg_and_generic().one_statement_parses_to("RELEASE SAVEPOINT foo", "RELEASE foo"); +} + #[test] fn parse_comments() { match pg().verified_stmt("COMMENT ON COLUMN tab.name IS 'comment'") {