Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions src/ast/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2922,6 +2922,26 @@ impl Spanned for RenameTableNameKind {
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
/// Whether the syntax used for the trigger object (ROW or STATEMENT) is `FOR` or `FOR EACH`.
pub enum TriggerObjectKind {
/// The `FOR` syntax is used.
For(TriggerObject),
/// The `FOR EACH` syntax is used.
ForEach(TriggerObject),
}

impl Display for TriggerObjectKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TriggerObjectKind::For(obj) => write!(f, "FOR {obj}"),
TriggerObjectKind::ForEach(obj) => write!(f, "FOR EACH {obj}"),
}
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand All @@ -2943,6 +2963,23 @@ pub struct CreateTrigger {
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql?view=sql-server-ver16#arguments)
pub or_alter: bool,
/// True if this is a temporary trigger.
///
/// Examples:
///
/// ```sql
/// CREATE TEMP TRIGGER trigger_name
/// ```
///
/// or
///
/// ```sql
/// CREATE TEMPORARY TRIGGER trigger_name;
/// CREATE TEMP TRIGGER trigger_name;
/// ```
///
/// [SQLite](https://sqlite.org/lang_createtrigger.html#temp_triggers_on_non_temp_tables)
pub temporary: bool,
/// The `OR REPLACE` clause is used to re-create the trigger if it already exists.
///
/// Example:
Expand Down Expand Up @@ -2987,6 +3024,8 @@ pub struct CreateTrigger {
/// ```
pub period: TriggerPeriod,
/// Whether the trigger period was specified before the target table name.
/// This does not refer to whether the period is BEFORE, AFTER, or INSTEAD OF,
/// but rather the position of the period clause in relation to the table name.
///
/// ```sql
/// -- period_before_table == true: Postgres, MySQL, and standard SQL
Expand All @@ -3006,9 +3045,9 @@ pub struct CreateTrigger {
pub referencing: Vec<TriggerReferencing>,
/// This specifies whether the trigger function should be fired once for
/// every row affected by the trigger event, or just once per SQL statement.
pub trigger_object: TriggerObject,
/// Whether to include the `EACH` term of the `FOR EACH`, as it is optional syntax.
pub include_each: bool,
/// This is optional in some SQL dialects, such as SQLite, and if not specified, in
/// those cases, the implied default is `FOR EACH ROW`.
pub trigger_object: Option<TriggerObjectKind>,
/// Triggering conditions
pub condition: Option<Expr>,
/// Execute logic block
Expand All @@ -3025,6 +3064,7 @@ impl Display for CreateTrigger {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let CreateTrigger {
or_alter,
temporary,
or_replace,
is_constraint,
name,
Expand All @@ -3036,15 +3076,15 @@ impl Display for CreateTrigger {
referencing,
trigger_object,
condition,
include_each,
exec_body,
statements_as,
statements,
characteristics,
} = self;
write!(
f,
"CREATE {or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
"CREATE {temporary}{or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
temporary = if *temporary { "TEMPORARY " } else { "" },
or_alter = if *or_alter { "OR ALTER " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
is_constraint = if *is_constraint { "CONSTRAINT " } else { "" },
Expand Down Expand Up @@ -3076,10 +3116,8 @@ impl Display for CreateTrigger {
write!(f, " REFERENCING {}", display_separated(referencing, " "))?;
}

if *include_each {
write!(f, " FOR EACH {trigger_object}")?;
} else if exec_body.is_some() {
write!(f, " FOR {trigger_object}")?;
if let Some(trigger_object) = trigger_object {
write!(f, " {trigger_object}")?;
}
if let Some(condition) = condition {
write!(f, " WHEN {condition}")?;
Expand Down
4 changes: 2 additions & 2 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub use self::ddl::{
IdentityParameters, IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind,
IdentityPropertyOrder, IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, Msck,
NullsDistinctOption, Owner, Partition, ProcedureParam, ReferentialAction, RenameTableNameKind,
ReplicaIdentity, TagsColumnOption, Truncate, UserDefinedTypeCompositeAttributeDef,
UserDefinedTypeRepresentation, ViewColumnDef,
ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate,
UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, ViewColumnDef,
};
pub use self::dml::{Delete, Insert, Update};
pub use self::operator::{BinaryOperator, UnaryOperator};
Expand Down
6 changes: 3 additions & 3 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, CreateTrigger,
GranteesType, IfStatement, Statement, TriggerObject,
GranteesType, IfStatement, Statement,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
Expand Down Expand Up @@ -254,6 +254,7 @@ impl MsSqlDialect {

Ok(CreateTrigger {
or_alter,
temporary: false,
or_replace: false,
is_constraint: false,
name,
Expand All @@ -263,8 +264,7 @@ impl MsSqlDialect {
table_name,
referenced_table_name: None,
referencing: Vec::new(),
trigger_object: TriggerObject::Statement,
include_each: false,
trigger_object: None,
condition: None,
exec_body: None,
statements_as: true,
Expand Down
45 changes: 30 additions & 15 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4753,9 +4753,9 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::DOMAIN) {
self.parse_create_domain()
} else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_alter, or_replace, false)
self.parse_create_trigger(temporary, or_alter, or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
self.parse_create_trigger(or_alter, or_replace, true)
self.parse_create_trigger(temporary, or_alter, or_replace, true)
} else if self.parse_keyword(Keyword::MACRO) {
self.parse_create_macro(or_replace, temporary)
} else if self.parse_keyword(Keyword::SECRET) {
Expand Down Expand Up @@ -5551,7 +5551,8 @@ impl<'a> Parser<'a> {
/// DROP TRIGGER [ IF EXISTS ] name ON table_name [ CASCADE | RESTRICT ]
/// ```
pub fn parse_drop_trigger(&mut self) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | SQLiteDialect | GenericDialect | MySqlDialect | MsSqlDialect)
{
self.prev_token();
return self.expected("an object type after DROP", self.peek_token());
}
Expand Down Expand Up @@ -5579,11 +5580,13 @@ impl<'a> Parser<'a> {

pub fn parse_create_trigger(
&mut self,
temporary: bool,
or_alter: bool,
or_replace: bool,
is_constraint: bool,
) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | SQLiteDialect | GenericDialect | MySqlDialect | MsSqlDialect)
{
self.prev_token();
return self.expected("an object type after CREATE", self.peek_token());
}
Expand All @@ -5610,14 +5613,25 @@ impl<'a> Parser<'a> {
}
}

self.expect_keyword_is(Keyword::FOR)?;
let include_each = self.parse_keyword(Keyword::EACH);
let trigger_object =
match self.expect_one_of_keywords(&[Keyword::ROW, Keyword::STATEMENT])? {
Keyword::ROW => TriggerObject::Row,
Keyword::STATEMENT => TriggerObject::Statement,
_ => unreachable!(),
};
let trigger_object = if self.parse_keyword(Keyword::FOR) {
let include_each = self.parse_keyword(Keyword::EACH);
let trigger_object =
match self.expect_one_of_keywords(&[Keyword::ROW, Keyword::STATEMENT])? {
Keyword::ROW => TriggerObject::Row,
Keyword::STATEMENT => TriggerObject::Statement,
_ => unreachable!(),
};

Some(if include_each {
TriggerObjectKind::ForEach(trigger_object)
} else {
TriggerObjectKind::For(trigger_object)
})
} else {
let _ = self.parse_keyword(Keyword::FOR);

None
};

let condition = self
.parse_keyword(Keyword::WHEN)
Expand All @@ -5632,8 +5646,9 @@ impl<'a> Parser<'a> {
statements = Some(self.parse_conditional_statements(&[Keyword::END])?);
}

Ok(Statement::CreateTrigger(CreateTrigger {
Ok(CreateTrigger {
or_alter,
temporary,
or_replace,
is_constraint,
name,
Expand All @@ -5644,13 +5659,13 @@ impl<'a> Parser<'a> {
referenced_table_name,
referencing,
trigger_object,
include_each,
condition,
exec_body,
statements_as: false,
statements,
characteristics,
}))
}
.into())
}

pub fn parse_trigger_period(&mut self) -> Result<TriggerPeriod, ParserError> {
Expand Down
4 changes: 2 additions & 2 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2388,6 +2388,7 @@ fn parse_create_trigger() {
create_stmt,
Statement::CreateTrigger(CreateTrigger {
or_alter: true,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("reminder1")]),
Expand All @@ -2397,8 +2398,7 @@ fn parse_create_trigger() {
table_name: ObjectName::from(vec![Ident::new("Sales"), Ident::new("Customer")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Statement,
include_each: false,
trigger_object: None,
condition: None,
exec_body: None,
statements_as: true,
Expand Down
4 changes: 2 additions & 2 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4018,6 +4018,7 @@ fn parse_create_trigger() {
create_stmt,
Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("emp_stamp")]),
Expand All @@ -4027,8 +4028,7 @@ fn parse_create_trigger() {
table_name: ObjectName::from(vec![Ident::new("emp")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand Down
Loading