diff --git a/CHANGELOG.md b/CHANGELOG.md index 41ca57af..c606d4d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Luau: Added support for parsing user-defined type functions ([#938](https://github.com/JohnnyMorganz/StyLua/issues/938)) + ### Fixed - Luau: fixed parentheses incorrectly removed in `(expr :: assertion) < foo` when multilining the expression, leading to a syntax error ([#940](https://github.com/JohnnyMorganz/StyLua/issues/940)) diff --git a/src/formatters/block.rs b/src/formatters/block.rs index 08ef6fad..a5e0db69 100644 --- a/src/formatters/block.rs +++ b/src/formatters/block.rs @@ -411,6 +411,20 @@ fn stmt_remove_leading_newlines(stmt: Stmt) -> Stmt { type_declaration.type_token(), with_type_token ), + #[cfg(feature = "luau")] + Stmt::ExportedTypeFunction(exported_type_function) => update_first_token!( + ExportedTypeFunction, + exported_type_function, + exported_type_function.export_token(), + with_export_token + ), + #[cfg(feature = "luau")] + Stmt::TypeFunction(type_function) => update_first_token!( + TypeFunction, + type_function, + type_function.type_token(), + with_type_token + ), #[cfg(feature = "lua52")] Stmt::Goto(goto) => update_first_token!(Goto, goto, goto.goto_token(), with_goto_token), #[cfg(feature = "lua52")] diff --git a/src/formatters/luau.rs b/src/formatters/luau.rs index 1505b36a..f41b9969 100644 --- a/src/formatters/luau.rs +++ b/src/formatters/luau.rs @@ -1,9 +1,12 @@ use crate::{ - context::{create_indent_trivia, create_newline_trivia, Context}, + context::{ + create_function_definition_trivia, create_indent_trivia, create_newline_trivia, Context, + }, fmt_op, fmt_symbol, formatters::{ assignment::hang_equal_token, expression::{format_expression, format_var}, + functions::format_function_body, general::{ format_contained_punctuated_multiline, format_contained_span, format_punctuated, format_symbol, format_token_reference, @@ -23,10 +26,10 @@ use crate::{ }; use full_moon::ast::{ luau::{ - CompoundAssignment, CompoundOp, ExportedTypeDeclaration, GenericDeclaration, - GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument, - TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection, - TypeSpecifier, TypeUnion, + CompoundAssignment, CompoundOp, ExportedTypeDeclaration, ExportedTypeFunction, + GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, + TypeArgument, TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction, + TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion, }, punctuated::Pair, }; @@ -1351,6 +1354,63 @@ pub fn format_type_declaration_stmt( format_type_declaration(ctx, type_declaration, true, shape) } +fn format_type_function( + ctx: &Context, + type_function: &TypeFunction, + add_leading_trivia: bool, + shape: Shape, +) -> TypeFunction { + const TYPE_TOKEN_LENGTH: usize = "type ".len(); + const FUNCTION_TOKEN_LENGTH: usize = "function ".len(); + + // Calculate trivia + let trailing_trivia = vec![create_newline_trivia(ctx)]; + let function_definition_trivia = vec![create_function_definition_trivia(ctx)]; + + let mut type_token = format_symbol( + ctx, + type_function.type_token(), + &TokenReference::new( + vec![], + Token::new(TokenType::Identifier { + identifier: "type".into(), + }), + vec![Token::new(TokenType::spaces(1))], + ), + shape, + ); + + if add_leading_trivia { + let leading_trivia = vec![create_indent_trivia(ctx, shape)]; + type_token = type_token.update_leading_trivia(FormatTriviaType::Append(leading_trivia)) + } + + let function_token = fmt_symbol!(ctx, type_function.function_token(), "function ", shape); + let function_name = format_token_reference(ctx, type_function.function_name(), shape) + .update_trailing_trivia(FormatTriviaType::Append(function_definition_trivia)); + + let shape = shape + + (TYPE_TOKEN_LENGTH + + FUNCTION_TOKEN_LENGTH + + strip_trivia(&function_name).to_string().len()); + let function_body = format_function_body(ctx, type_function.function_body(), shape) + .update_trailing_trivia(FormatTriviaType::Append(trailing_trivia)); + + TypeFunction::new(function_name, function_body) + .with_type_token(type_token) + .with_function_token(function_token) +} + +/// Wrapper around `format_type_function` for statements +/// This is required as `format_type_function` is also used for ExportedTypeFunction, and we don't want leading trivia there +pub fn format_type_function_stmt( + ctx: &Context, + type_function: &TypeFunction, + shape: Shape, +) -> TypeFunction { + format_type_function(ctx, type_function, true, shape) +} + fn format_generic_parameter( ctx: &Context, generic_parameter: &GenericDeclarationParameter, @@ -1478,3 +1538,38 @@ pub fn format_exported_type_declaration( .with_export_token(export_token) .with_type_declaration(type_declaration) } + +pub fn format_exported_type_function( + ctx: &Context, + exported_type_function: &ExportedTypeFunction, + shape: Shape, +) -> ExportedTypeFunction { + // Calculate trivia + let shape = shape.reset(); + let leading_trivia = vec![create_indent_trivia(ctx, shape)]; + + let export_token = format_symbol( + ctx, + exported_type_function.export_token(), + &TokenReference::new( + vec![], + Token::new(TokenType::Identifier { + identifier: "export".into(), + }), + vec![Token::new(TokenType::spaces(1))], + ), + shape, + ) + .update_leading_trivia(FormatTriviaType::Append(leading_trivia)); + let type_function = format_type_function( + ctx, + exported_type_function.type_function(), + false, + shape + 7, // 7 = "export " + ); + + exported_type_function + .to_owned() + .with_export_token(export_token) + .with_type_function(type_function) +} diff --git a/src/formatters/stmt.rs b/src/formatters/stmt.rs index d6fb7d67..ce312ba7 100644 --- a/src/formatters/stmt.rs +++ b/src/formatters/stmt.rs @@ -2,8 +2,8 @@ use crate::formatters::lua52::{format_goto, format_goto_no_trivia, format_label}; #[cfg(feature = "luau")] use crate::formatters::luau::{ - format_compound_assignment, format_exported_type_declaration, format_type_declaration_stmt, - format_type_specifier, + format_compound_assignment, format_exported_type_declaration, format_exported_type_function, + format_type_declaration_stmt, format_type_function_stmt, format_type_specifier, }; use crate::{ context::{create_indent_trivia, create_newline_trivia, Context, FormatNode}, @@ -793,6 +793,8 @@ pub fn format_function_call_stmt( /// These are used for range formatting pub(crate) mod stmt_block { use crate::{context::Context, formatters::block::format_block, shape::Shape}; + #[cfg(feature = "luau")] + use full_moon::ast::luau::TypeFunction; use full_moon::ast::{ Call, Expression, Field, FunctionArgs, FunctionCall, Index, Prefix, Stmt, Suffix, TableConstructor, @@ -907,6 +909,17 @@ pub(crate) mod stmt_block { .with_suffixes(suffixes) } + #[cfg(feature = "luau")] + fn format_type_function_block( + ctx: &Context, + type_function: &TypeFunction, + shape: Shape, + ) -> TypeFunction { + let block = format_block(ctx, type_function.function_body().block(), shape); + let body = type_function.function_body().to_owned().with_block(block); + type_function.to_owned().with_function_body(body) + } + /// Only formats a block within an expression pub fn format_expression_block( ctx: &Context, @@ -1057,6 +1070,23 @@ pub(crate) mod stmt_block { Stmt::ExportedTypeDeclaration(node) => Stmt::ExportedTypeDeclaration(node.to_owned()), #[cfg(feature = "luau")] Stmt::TypeDeclaration(node) => Stmt::TypeDeclaration(node.to_owned()), + #[cfg(feature = "luau")] + Stmt::ExportedTypeFunction(exported_type_function) => { + let type_function = format_type_function_block( + ctx, + exported_type_function.type_function(), + block_shape, + ); + Stmt::ExportedTypeFunction( + exported_type_function + .to_owned() + .with_type_function(type_function), + ) + } + #[cfg(feature = "luau")] + Stmt::TypeFunction(type_function) => { + Stmt::TypeFunction(format_type_function_block(ctx, type_function, block_shape)) + } #[cfg(feature = "lua52")] Stmt::Goto(node) => Stmt::Goto(node.to_owned()), #[cfg(feature = "lua52")] @@ -1090,6 +1120,8 @@ pub fn format_stmt(ctx: &Context, stmt: &Stmt, shape: Shape) -> Stmt { #[cfg(feature = "luau")] CompoundAssignment = format_compound_assignment, #[cfg(feature = "luau")] ExportedTypeDeclaration = format_exported_type_declaration, #[cfg(feature = "luau")] TypeDeclaration = format_type_declaration_stmt, + #[cfg(feature = "luau")] ExportedTypeFunction = format_exported_type_function, + #[cfg(feature = "luau")] TypeFunction = format_type_function_stmt, #[cfg(feature = "lua52")] Goto = format_goto, #[cfg(feature = "lua52")] Label = format_label, }) diff --git a/src/formatters/trivia.rs b/src/formatters/trivia.rs index 97152c19..2e276696 100644 --- a/src/formatters/trivia.rs +++ b/src/formatters/trivia.rs @@ -4,8 +4,8 @@ use full_moon::ast::lua54::Attribute; use full_moon::ast::luau::{ ElseIfExpression, GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo, IfExpression, IndexedTypeInfo, InterpolatedString, InterpolatedStringSegment, TypeArgument, - TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection, - TypeSpecifier, TypeUnion, + TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction, TypeInfo, + TypeIntersection, TypeSpecifier, TypeUnion, }; use full_moon::ast::{ punctuated::Punctuated, span::ContainedSpan, BinOp, Call, Expression, FunctionArgs, @@ -680,6 +680,18 @@ define_update_trivia!(Stmt, |this, leading, trailing| { } #[cfg(feature = "luau")] Stmt::TypeDeclaration(stmt) => Stmt::TypeDeclaration(stmt.update_trivia(leading, trailing)), + #[cfg(feature = "luau")] + Stmt::ExportedTypeFunction(stmt) => { + let export_token = stmt.export_token().update_leading_trivia(leading); + let type_function = stmt.type_function().update_trailing_trivia(trailing); + Stmt::ExportedTypeFunction( + stmt.to_owned() + .with_export_token(export_token) + .with_type_function(type_function), + ) + } + #[cfg(feature = "luau")] + Stmt::TypeFunction(stmt) => Stmt::TypeFunction(stmt.update_trivia(leading, trailing)), #[cfg(feature = "lua52")] Stmt::Goto(stmt) => Stmt::Goto( stmt.to_owned() @@ -920,6 +932,13 @@ define_update_trivia!(TypeDeclaration, |this, leading, trailing| { .with_type_definition(this.type_definition().update_trailing_trivia(trailing)) }); +#[cfg(feature = "luau")] +define_update_trivia!(TypeFunction, |this, leading, trailing| { + this.to_owned() + .with_type_token(this.type_token().update_leading_trivia(leading)) + .with_function_body(this.function_body().update_trailing_trivia(trailing)) +}); + #[cfg(feature = "luau")] define_update_trailing_trivia!(IndexedTypeInfo, |this, trailing| { match this { diff --git a/src/formatters/trivia_util.rs b/src/formatters/trivia_util.rs index 0b76dcb6..38fbc1ea 100644 --- a/src/formatters/trivia_util.rs +++ b/src/formatters/trivia_util.rs @@ -6,13 +6,14 @@ use crate::{ #[cfg(feature = "luau")] use full_moon::ast::luau::{ GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument, - TypeDeclaration, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion, + TypeDeclaration, TypeFunction, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion, }; use full_moon::{ ast::{ punctuated::{Pair, Punctuated}, - BinOp, Block, Call, Expression, Field, FunctionArgs, Index, LastStmt, LocalAssignment, - Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var, VarExpression, + BinOp, Block, Call, Expression, Field, FunctionArgs, FunctionBody, Index, LastStmt, + LocalAssignment, Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var, + VarExpression, }, node::Node, tokenizer::{Token, TokenKind, TokenReference, TokenType}, @@ -263,6 +264,12 @@ impl GetTrailingTrivia for FunctionArgs { } } +impl GetTrailingTrivia for FunctionBody { + fn trailing_trivia(&self) -> Vec { + GetTrailingTrivia::trailing_trivia(self.end_token()) + } +} + impl GetTrailingTrivia for Prefix { fn trailing_trivia(&self) -> Vec { match self { @@ -522,7 +529,6 @@ pub fn take_leading_comments( ) } -#[cfg(feature = "luau")] pub fn take_trailing_trivia( node: &T, ) -> (T, Vec) { @@ -714,6 +720,13 @@ impl GetTrailingTrivia for TypeDeclaration { } } +#[cfg(feature = "luau")] +impl GetTrailingTrivia for TypeFunction { + fn trailing_trivia(&self) -> Vec { + self.function_body().trailing_trivia() + } +} + #[cfg(feature = "luau")] impl GetTrailingTrivia for TypeSpecifier { fn trailing_trivia(&self) -> Vec { @@ -873,22 +886,14 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec) { end_stmt_trailing_trivia!(If, stmt) } Stmt::FunctionDeclaration(stmt) => { - let end_token = stmt.body().end_token(); - let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect(); - let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![])); - - let body = stmt.body().to_owned().with_end_token(new_end_token); + let (body, trailing_trivia) = take_trailing_trivia(stmt.body()); ( Stmt::FunctionDeclaration(stmt.with_body(body)), trailing_trivia, ) } Stmt::LocalFunction(stmt) => { - let end_token = stmt.body().end_token(); - let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect(); - let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![])); - - let body = stmt.body().to_owned().with_end_token(new_end_token); + let (body, trailing_trivia) = take_trailing_trivia(stmt.body()); (Stmt::LocalFunction(stmt.with_body(body)), trailing_trivia) } Stmt::NumericFor(stmt) => { @@ -922,6 +927,19 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec) { let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt); (Stmt::TypeDeclaration(type_declaration), trailing_trivia) } + #[cfg(feature = "luau")] + Stmt::ExportedTypeFunction(stmt) => { + let (type_function, trailing_trivia) = take_trailing_trivia(stmt.type_function()); + ( + Stmt::ExportedTypeFunction(stmt.with_type_function(type_function)), + trailing_trivia, + ) + } + #[cfg(feature = "luau")] + Stmt::TypeFunction(stmt) => { + let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt); + (Stmt::TypeFunction(type_declaration), trailing_trivia) + } #[cfg(feature = "lua52")] Stmt::Goto(stmt) => { let trailing_trivia = stmt diff --git a/tests/inputs-luau/type-functions-1.lua b/tests/inputs-luau/type-functions-1.lua new file mode 100644 index 00000000..d00a1fd9 --- /dev/null +++ b/tests/inputs-luau/type-functions-1.lua @@ -0,0 +1,5 @@ +type function Foo(x) +end + +export type function Foo(x) +end diff --git a/tests/snapshots/tests__luau@type-functions-1.lua.snap b/tests/snapshots/tests__luau@type-functions-1.lua.snap new file mode 100644 index 00000000..2bcf0a09 --- /dev/null +++ b/tests/snapshots/tests__luau@type-functions-1.lua.snap @@ -0,0 +1,9 @@ +--- +source: tests/tests.rs +expression: "format(&contents, LuaVersion::Luau)" +input_file: tests/inputs-luau/type-functions-1.lua +snapshot_kind: text +--- +type function Foo(x) end + +export type function Foo(x) end