diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index f998143c8..773b119fb 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -1458,12 +1458,19 @@ pub struct ProcedureParam { pub name: Ident, pub data_type: DataType, pub mode: Option, + pub default: Option, } impl fmt::Display for ProcedureParam { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(mode) = &self.mode { - write!(f, "{mode} {} {}", self.name, self.data_type) + if let Some(default) = &self.default { + write!(f, "{mode} {} {} = {}", self.name, self.data_type, default) + } else { + write!(f, "{mode} {} {}", self.name, self.data_type) + } + } else if let Some(default) = &self.default { + write!(f, "{} {} = {}", self.name, self.data_type, default) } else { write!(f, "{} {}", self.name, self.data_type) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d661efd4d..9682804c2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -7897,10 +7897,17 @@ impl<'a> Parser<'a> { }; let name = self.parse_identifier()?; let data_type = self.parse_data_type()?; + let default = if self.consume_token(&Token::Eq) { + Some(self.parse_expr()?) + } else { + None + }; + Ok(ProcedureParam { name, data_type, mode, + default, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 720c1e492..489ea3f04 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -16497,7 +16497,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Integer(None), - mode: Some(ArgMode::In) + mode: Some(ArgMode::In), + default: None, }, ProcedureParam { name: Ident { @@ -16506,7 +16507,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Text, - mode: Some(ArgMode::Out) + mode: Some(ArgMode::Out), + default: None, }, ProcedureParam { name: Ident { @@ -16515,7 +16517,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Timestamp(None, TimezoneInfo::None), - mode: Some(ArgMode::InOut) + mode: Some(ArgMode::InOut), + default: None, }, ProcedureParam { name: Ident { @@ -16524,13 +16527,60 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Bool, - mode: None + mode: None, + default: None, }, ]) ); } _ => unreachable!(), } + + // parameters with default values + let sql = r#"CREATE PROCEDURE test_proc (IN a INTEGER = 1, OUT b TEXT = '2', INOUT c TIMESTAMP = NULL, d BOOL = 0) AS BEGIN SELECT 1; END"#; + match verified_stmt(sql) { + Statement::CreateProcedure { + or_alter, + name, + params, + .. + } => { + assert_eq!(or_alter, false); + assert_eq!(name.to_string(), "test_proc"); + assert_eq!( + params, + Some(vec![ + ProcedureParam { + name: Ident::new("a"), + data_type: DataType::Integer(None), + mode: Some(ArgMode::In), + default: Some(Expr::Value((number("1")).with_empty_span())), + }, + ProcedureParam { + name: Ident::new("b"), + data_type: DataType::Text, + mode: Some(ArgMode::Out), + default: Some(Expr::Value( + Value::SingleQuotedString("2".into()).with_empty_span() + )), + }, + ProcedureParam { + name: Ident::new("c"), + data_type: DataType::Timestamp(None, TimezoneInfo::None), + mode: Some(ArgMode::InOut), + default: Some(Expr::Value(Value::Null.with_empty_span())), + }, + ProcedureParam { + name: Ident::new("d"), + data_type: DataType::Bool, + mode: None, + default: Some(Expr::Value((number("0")).with_empty_span())), + } + ]), + ); + } + _ => unreachable!(), + } } #[test] diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index b1ad422ec..181899170 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -156,6 +156,7 @@ fn parse_create_procedure() { }, data_type: DataType::Int(None), mode: None, + default: None, }, ProcedureParam { name: Ident { @@ -168,6 +169,7 @@ fn parse_create_procedure() { unit: None })), mode: None, + default: None, } ]), name: ObjectName::from(vec![Ident { @@ -196,6 +198,10 @@ fn parse_mssql_create_procedure() { let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN SELECT [foo], CASE WHEN [foo] IS NULL THEN 'empty' ELSE 'notempty' END AS [foo]; END"); // Multiple statements let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10; END"); + + // parameters with default values + let sql = r#"CREATE PROCEDURE foo (IN @a INTEGER = 1, OUT @b TEXT = '2', INOUT @c DATETIME = NULL, @d BOOL = 0) AS BEGIN SELECT 1; END"#; + let _ = ms().verified_stmt(sql); } #[test]