Skip to content

Commit 74148fe

Browse files
committed
Add support for CREATE TRIGGER for SQL Server
- similar to functions & procedures, this dialect can define triggers with a multi statement block - there's no `EXECUTE` keyword here, so that means the `exec_body` used by other dialects becomes an `Option`, and our `statements` is also optional for them
1 parent 580114e commit 74148fe

File tree

6 files changed

+306
-23
lines changed

6 files changed

+306
-23
lines changed

src/ast/mod.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3646,6 +3646,7 @@ pub enum Statement {
36463646
/// ```
36473647
///
36483648
/// Postgres: <https://www.postgresql.org/docs/current/sql-createtrigger.html>
3649+
/// SQL Server: <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
36493650
CreateTrigger {
36503651
/// The `OR REPLACE` clause is used to re-create the trigger if it already exists.
36513652
///
@@ -3707,7 +3708,9 @@ pub enum Statement {
37073708
/// Triggering conditions
37083709
condition: Option<Expr>,
37093710
/// Execute logic block
3710-
exec_body: TriggerExecBody,
3711+
exec_body: Option<TriggerExecBody>,
3712+
/// For SQL dialects with statement(s) for a body
3713+
statements: Option<Vec<Statement>>,
37113714
/// The characteristic of the trigger, which include whether the trigger is `DEFERRABLE`, `INITIALLY DEFERRED`, or `INITIALLY IMMEDIATE`,
37123715
characteristics: Option<ConstraintCharacteristics>,
37133716
},
@@ -4079,6 +4082,11 @@ pub enum Statement {
40794082
/// RETURN scalar_expression
40804083
///
40814084
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql>
4085+
///
4086+
/// for Triggers:
4087+
/// RETURN
4088+
///
4089+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
40824090
Return(ReturnStatement),
40834091
}
40844092

@@ -4508,6 +4516,7 @@ impl fmt::Display for Statement {
45084516
condition,
45094517
include_each,
45104518
exec_body,
4519+
statements,
45114520
characteristics,
45124521
} => {
45134522
write!(
@@ -4542,7 +4551,20 @@ impl fmt::Display for Statement {
45424551
if let Some(condition) = condition {
45434552
write!(f, " WHEN {condition}")?;
45444553
}
4545-
write!(f, " EXECUTE {exec_body}")
4554+
if let Some(exec_body) = exec_body {
4555+
write!(f, " EXECUTE {exec_body}")?;
4556+
}
4557+
if let Some(statements) = statements {
4558+
write!(f, " AS ")?;
4559+
if statements.len() > 1 {
4560+
write!(f, "BEGIN ")?;
4561+
}
4562+
write!(f, "{}", display_separated(statements, "; "))?;
4563+
if statements.len() > 1 {
4564+
write!(f, " END")?;
4565+
}
4566+
}
4567+
Ok(())
45464568
}
45474569
Statement::DropTrigger {
45484570
if_exists,

src/ast/trigger.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ impl fmt::Display for TriggerEvent {
110110
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
111111
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
112112
pub enum TriggerPeriod {
113+
For,
113114
After,
114115
Before,
115116
InsteadOf,
@@ -118,6 +119,7 @@ pub enum TriggerPeriod {
118119
impl fmt::Display for TriggerPeriod {
119120
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120121
match self {
122+
TriggerPeriod::For => write!(f, "FOR"),
121123
TriggerPeriod::After => write!(f, "AFTER"),
122124
TriggerPeriod::Before => write!(f, "BEFORE"),
123125
TriggerPeriod::InsteadOf => write!(f, "INSTEAD OF"),

src/parser/mod.rs

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5278,11 +5278,15 @@ impl<'a> Parser<'a> {
52785278
or_replace: bool,
52795279
is_constraint: bool,
52805280
) -> Result<Statement, ParserError> {
5281-
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect) {
5281+
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
52825282
self.prev_token();
52835283
return self.expected("an object type after CREATE", self.peek_token());
52845284
}
52855285

5286+
if dialect_of!(self is MsSqlDialect) {
5287+
return self.parse_mssql_create_trigger(or_replace, is_constraint);
5288+
}
5289+
52865290
let name = self.parse_object_name(false)?;
52875291
let period = self.parse_trigger_period()?;
52885292

@@ -5335,18 +5339,64 @@ impl<'a> Parser<'a> {
53355339
trigger_object,
53365340
include_each,
53375341
condition,
5338-
exec_body,
5342+
exec_body: Some(exec_body),
5343+
statements: None,
53395344
characteristics,
53405345
})
53415346
}
53425347

5348+
/// Parse `CREATE TRIGGER` for [MsSql]
5349+
///
5350+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql
5351+
pub fn parse_mssql_create_trigger(
5352+
&mut self,
5353+
or_replace: bool,
5354+
is_constraint: bool,
5355+
) -> Result<Statement, ParserError> {
5356+
let name = self.parse_object_name(false)?;
5357+
self.expect_keyword_is(Keyword::ON)?;
5358+
let table_name = self.parse_object_name(false)?;
5359+
let period = self.parse_trigger_period()?;
5360+
let events = self.parse_comma_separated(Parser::parse_trigger_event)?;
5361+
5362+
self.expect_keyword_is(Keyword::AS)?;
5363+
5364+
let statements = if self.peek_keyword(Keyword::BEGIN) {
5365+
self.expect_keyword_is(Keyword::BEGIN)?;
5366+
let statements = self.parse_statement_list(&[Keyword::END])?;
5367+
self.expect_keyword_is(Keyword::END)?;
5368+
statements
5369+
} else {
5370+
vec![self.parse_statement()?]
5371+
};
5372+
5373+
Ok(Statement::CreateTrigger {
5374+
or_replace,
5375+
is_constraint,
5376+
name,
5377+
period,
5378+
events,
5379+
table_name,
5380+
referenced_table_name: None,
5381+
referencing: Vec::new(),
5382+
trigger_object: TriggerObject::Statement,
5383+
include_each: false,
5384+
condition: None,
5385+
exec_body: None,
5386+
statements: Some(statements),
5387+
characteristics: None,
5388+
})
5389+
}
5390+
53435391
pub fn parse_trigger_period(&mut self) -> Result<TriggerPeriod, ParserError> {
53445392
Ok(
53455393
match self.expect_one_of_keywords(&[
5394+
Keyword::FOR,
53465395
Keyword::BEFORE,
53475396
Keyword::AFTER,
53485397
Keyword::INSTEAD,
53495398
])? {
5399+
Keyword::FOR => TriggerPeriod::For,
53505400
Keyword::BEFORE => TriggerPeriod::Before,
53515401
Keyword::AFTER => TriggerPeriod::After,
53525402
Keyword::INSTEAD => self
@@ -15130,10 +15180,12 @@ impl<'a> Parser<'a> {
1513015180

1513115181
/// Parse [Statement::Return]
1513215182
fn parse_return(&mut self) -> Result<Statement, ParserError> {
15133-
let expr = self.parse_expr()?;
15134-
Ok(Statement::Return(ReturnStatement {
15135-
value: Some(ReturnStatementValue::Expr(expr)),
15136-
}))
15183+
match self.maybe_parse(|p| p.parse_expr())? {
15184+
Some(expr) => Ok(Statement::Return(ReturnStatement {
15185+
value: Some(ReturnStatementValue::Expr(expr)),
15186+
})),
15187+
None => Ok(Statement::Return(ReturnStatement { value: None })),
15188+
}
1513715189
}
1513815190

1513915191
/// Consume the parser and return its underlying token buffer

tests/sqlparser_mssql.rs

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,6 +2420,206 @@ fn parse_mssql_merge_with_output() {
24202420
ms_and_generic().verified_stmt(stmt);
24212421
}
24222422

2423+
#[test]
2424+
fn parse_create_trigger() {
2425+
let create_trigger = r#"
2426+
CREATE TRIGGER reminder1
2427+
ON Sales.Customer
2428+
AFTER INSERT, UPDATE
2429+
AS RAISERROR ('Notify Customer Relations', 16, 10);
2430+
"#;
2431+
let create_stmt = ms().one_statement_parses_to(create_trigger, "");
2432+
assert_eq!(
2433+
create_stmt,
2434+
Statement::CreateTrigger {
2435+
or_replace: false,
2436+
is_constraint: false,
2437+
name: ObjectName::from(vec![Ident::new("reminder1")]),
2438+
period: TriggerPeriod::After,
2439+
events: vec![TriggerEvent::Insert, TriggerEvent::Update(vec![]),],
2440+
table_name: ObjectName::from(vec![Ident::new("Sales"), Ident::new("Customer")]),
2441+
referenced_table_name: None,
2442+
referencing: vec![],
2443+
trigger_object: TriggerObject::Statement,
2444+
include_each: false,
2445+
condition: None,
2446+
exec_body: None,
2447+
statements: Some(vec![Statement::RaisError {
2448+
message: Box::new(Expr::Value(
2449+
(Value::SingleQuotedString("Notify Customer Relations".to_string()))
2450+
.with_empty_span()
2451+
)),
2452+
severity: Box::new(Expr::Value(
2453+
(Value::Number("16".parse().unwrap(), false)).with_empty_span()
2454+
)),
2455+
state: Box::new(Expr::Value(
2456+
(Value::Number("10".parse().unwrap(), false)).with_empty_span()
2457+
)),
2458+
arguments: vec![],
2459+
options: vec![],
2460+
}]),
2461+
characteristics: None,
2462+
}
2463+
);
2464+
2465+
let multi_statement_trigger = r#"
2466+
CREATE TRIGGER some_trigger ON some_table FOR INSERT
2467+
AS
2468+
BEGIN
2469+
RAISERROR('Trigger fired', 10, 1);
2470+
END
2471+
"#;
2472+
let create_stmt = ms().one_statement_parses_to(multi_statement_trigger, "");
2473+
assert_eq!(
2474+
create_stmt,
2475+
Statement::CreateTrigger {
2476+
or_replace: false,
2477+
is_constraint: false,
2478+
name: ObjectName::from(vec![Ident::new("some_trigger")]),
2479+
period: TriggerPeriod::For,
2480+
events: vec![TriggerEvent::Insert],
2481+
table_name: ObjectName::from(vec![Ident::new("some_table")]),
2482+
referenced_table_name: None,
2483+
referencing: vec![],
2484+
trigger_object: TriggerObject::Statement,
2485+
include_each: false,
2486+
condition: None,
2487+
exec_body: None,
2488+
statements: Some(vec![Statement::RaisError {
2489+
message: Box::new(Expr::Value(
2490+
(Value::SingleQuotedString("Trigger fired".to_string())).with_empty_span()
2491+
)),
2492+
severity: Box::new(Expr::Value(
2493+
(Value::Number("10".parse().unwrap(), false)).with_empty_span()
2494+
)),
2495+
state: Box::new(Expr::Value(
2496+
(Value::Number("1".parse().unwrap(), false)).with_empty_span()
2497+
)),
2498+
arguments: vec![],
2499+
options: vec![],
2500+
}]),
2501+
characteristics: None,
2502+
}
2503+
);
2504+
2505+
let create_trigger_with_return = r#"
2506+
CREATE TRIGGER some_trigger ON some_table FOR INSERT
2507+
AS
2508+
BEGIN
2509+
RETURN;
2510+
END
2511+
"#;
2512+
let create_stmt = ms().one_statement_parses_to(create_trigger_with_return, "");
2513+
assert_eq!(
2514+
create_stmt,
2515+
Statement::CreateTrigger {
2516+
or_replace: false,
2517+
is_constraint: false,
2518+
name: ObjectName::from(vec![Ident::new("some_trigger")]),
2519+
period: TriggerPeriod::For,
2520+
events: vec![TriggerEvent::Insert],
2521+
table_name: ObjectName::from(vec![Ident::new("some_table")]),
2522+
referenced_table_name: None,
2523+
referencing: vec![],
2524+
trigger_object: TriggerObject::Statement,
2525+
include_each: false,
2526+
condition: None,
2527+
exec_body: None,
2528+
statements: Some(vec![Statement::Return(ReturnStatement { value: None })]),
2529+
characteristics: None,
2530+
}
2531+
);
2532+
2533+
let create_trigger_with_conditional = r#"
2534+
CREATE TRIGGER some_trigger ON some_table FOR INSERT
2535+
AS
2536+
BEGIN
2537+
IF 1=2
2538+
BEGIN
2539+
RAISERROR('Trigger fired', 10, 1);
2540+
END
2541+
2542+
RETURN;
2543+
END
2544+
"#;
2545+
let create_stmt = ms().one_statement_parses_to(create_trigger_with_conditional, "");
2546+
assert_eq!(
2547+
create_stmt,
2548+
Statement::CreateTrigger {
2549+
or_replace: false,
2550+
is_constraint: false,
2551+
name: ObjectName::from(vec![Ident::new("some_trigger")]),
2552+
period: TriggerPeriod::For,
2553+
events: vec![TriggerEvent::Insert],
2554+
table_name: ObjectName::from(vec![Ident::new("some_table")]),
2555+
referenced_table_name: None,
2556+
referencing: vec![],
2557+
trigger_object: TriggerObject::Statement,
2558+
include_each: false,
2559+
condition: None,
2560+
exec_body: None,
2561+
statements: Some(vec![
2562+
Statement::If(IfStatement {
2563+
if_block: ConditionalStatementBlock {
2564+
start_token: AttachedToken(TokenWithSpan::wrap(
2565+
sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
2566+
value: "IF".to_string(),
2567+
quote_style: None,
2568+
keyword: Keyword::IF
2569+
})
2570+
)),
2571+
condition: Some(Expr::BinaryOp {
2572+
left: Box::new(Expr::Value(number("1").with_empty_span())),
2573+
op: sqlparser::ast::BinaryOperator::Eq,
2574+
right: Box::new(Expr::Value(number("2").with_empty_span())),
2575+
}),
2576+
then_token: None,
2577+
conditional_statements: ConditionalStatements::BeginEnd(
2578+
BeginEndStatements {
2579+
begin_token: AttachedToken(TokenWithSpan::wrap(
2580+
sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
2581+
value: "BEGIN".to_string(),
2582+
quote_style: None,
2583+
keyword: Keyword::BEGIN
2584+
})
2585+
)),
2586+
statements: vec![Statement::RaisError {
2587+
message: Box::new(Expr::Value(
2588+
(Value::SingleQuotedString("Trigger fired".to_string()))
2589+
.with_empty_span()
2590+
)),
2591+
severity: Box::new(Expr::Value(
2592+
(Value::Number("10".parse().unwrap(), false))
2593+
.with_empty_span()
2594+
)),
2595+
state: Box::new(Expr::Value(
2596+
(Value::Number("1".parse().unwrap(), false))
2597+
.with_empty_span()
2598+
)),
2599+
arguments: vec![],
2600+
options: vec![],
2601+
}],
2602+
end_token: AttachedToken(TokenWithSpan::wrap(
2603+
sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
2604+
value: "END".to_string(),
2605+
quote_style: None,
2606+
keyword: Keyword::END
2607+
})
2608+
)),
2609+
}
2610+
),
2611+
},
2612+
elseif_blocks: vec![],
2613+
else_block: None,
2614+
end_token: None,
2615+
}),
2616+
Statement::Return(ReturnStatement { value: None }),
2617+
]),
2618+
characteristics: None,
2619+
}
2620+
);
2621+
}
2622+
24232623
#[test]
24242624
fn parse_drop_trigger() {
24252625
let sql_drop_trigger = "DROP TRIGGER emp_stamp;";

0 commit comments

Comments
 (0)