Skip to content

Commit 57408a8

Browse files
Luau: Support user-defined type functions (#947)
* Support user-defined type functions * Add test case * Update changelog
1 parent b349c0a commit 57408a8

File tree

8 files changed

+219
-23
lines changed

8 files changed

+219
-23
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- Luau: Added support for parsing user-defined type functions ([#938](https://github.com/JohnnyMorganz/StyLua/issues/938))
13+
1014
### Fixed
1115

1216
- Luau: fixed parentheses incorrectly removed in `(expr :: assertion) < foo` when multilining the expression, leading to a syntax error ([#940](https://github.com/JohnnyMorganz/StyLua/issues/940))

src/formatters/block.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,20 @@ fn stmt_remove_leading_newlines(stmt: Stmt) -> Stmt {
411411
type_declaration.type_token(),
412412
with_type_token
413413
),
414+
#[cfg(feature = "luau")]
415+
Stmt::ExportedTypeFunction(exported_type_function) => update_first_token!(
416+
ExportedTypeFunction,
417+
exported_type_function,
418+
exported_type_function.export_token(),
419+
with_export_token
420+
),
421+
#[cfg(feature = "luau")]
422+
Stmt::TypeFunction(type_function) => update_first_token!(
423+
TypeFunction,
424+
type_function,
425+
type_function.type_token(),
426+
with_type_token
427+
),
414428
#[cfg(feature = "lua52")]
415429
Stmt::Goto(goto) => update_first_token!(Goto, goto, goto.goto_token(), with_goto_token),
416430
#[cfg(feature = "lua52")]

src/formatters/luau.rs

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use crate::{
2-
context::{create_indent_trivia, create_newline_trivia, Context},
2+
context::{
3+
create_function_definition_trivia, create_indent_trivia, create_newline_trivia, Context,
4+
},
35
fmt_op, fmt_symbol,
46
formatters::{
57
assignment::hang_equal_token,
68
expression::{format_expression, format_var},
9+
functions::format_function_body,
710
general::{
811
format_contained_punctuated_multiline, format_contained_span, format_punctuated,
912
format_symbol, format_token_reference,
@@ -23,10 +26,10 @@ use crate::{
2326
};
2427
use full_moon::ast::{
2528
luau::{
26-
CompoundAssignment, CompoundOp, ExportedTypeDeclaration, GenericDeclaration,
27-
GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument,
28-
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection,
29-
TypeSpecifier, TypeUnion,
29+
CompoundAssignment, CompoundOp, ExportedTypeDeclaration, ExportedTypeFunction,
30+
GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo,
31+
TypeArgument, TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction,
32+
TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
3033
},
3134
punctuated::Pair,
3235
};
@@ -1351,6 +1354,63 @@ pub fn format_type_declaration_stmt(
13511354
format_type_declaration(ctx, type_declaration, true, shape)
13521355
}
13531356

1357+
fn format_type_function(
1358+
ctx: &Context,
1359+
type_function: &TypeFunction,
1360+
add_leading_trivia: bool,
1361+
shape: Shape,
1362+
) -> TypeFunction {
1363+
const TYPE_TOKEN_LENGTH: usize = "type ".len();
1364+
const FUNCTION_TOKEN_LENGTH: usize = "function ".len();
1365+
1366+
// Calculate trivia
1367+
let trailing_trivia = vec![create_newline_trivia(ctx)];
1368+
let function_definition_trivia = vec![create_function_definition_trivia(ctx)];
1369+
1370+
let mut type_token = format_symbol(
1371+
ctx,
1372+
type_function.type_token(),
1373+
&TokenReference::new(
1374+
vec![],
1375+
Token::new(TokenType::Identifier {
1376+
identifier: "type".into(),
1377+
}),
1378+
vec![Token::new(TokenType::spaces(1))],
1379+
),
1380+
shape,
1381+
);
1382+
1383+
if add_leading_trivia {
1384+
let leading_trivia = vec![create_indent_trivia(ctx, shape)];
1385+
type_token = type_token.update_leading_trivia(FormatTriviaType::Append(leading_trivia))
1386+
}
1387+
1388+
let function_token = fmt_symbol!(ctx, type_function.function_token(), "function ", shape);
1389+
let function_name = format_token_reference(ctx, type_function.function_name(), shape)
1390+
.update_trailing_trivia(FormatTriviaType::Append(function_definition_trivia));
1391+
1392+
let shape = shape
1393+
+ (TYPE_TOKEN_LENGTH
1394+
+ FUNCTION_TOKEN_LENGTH
1395+
+ strip_trivia(&function_name).to_string().len());
1396+
let function_body = format_function_body(ctx, type_function.function_body(), shape)
1397+
.update_trailing_trivia(FormatTriviaType::Append(trailing_trivia));
1398+
1399+
TypeFunction::new(function_name, function_body)
1400+
.with_type_token(type_token)
1401+
.with_function_token(function_token)
1402+
}
1403+
1404+
/// Wrapper around `format_type_function` for statements
1405+
/// This is required as `format_type_function` is also used for ExportedTypeFunction, and we don't want leading trivia there
1406+
pub fn format_type_function_stmt(
1407+
ctx: &Context,
1408+
type_function: &TypeFunction,
1409+
shape: Shape,
1410+
) -> TypeFunction {
1411+
format_type_function(ctx, type_function, true, shape)
1412+
}
1413+
13541414
fn format_generic_parameter(
13551415
ctx: &Context,
13561416
generic_parameter: &GenericDeclarationParameter,
@@ -1478,3 +1538,38 @@ pub fn format_exported_type_declaration(
14781538
.with_export_token(export_token)
14791539
.with_type_declaration(type_declaration)
14801540
}
1541+
1542+
pub fn format_exported_type_function(
1543+
ctx: &Context,
1544+
exported_type_function: &ExportedTypeFunction,
1545+
shape: Shape,
1546+
) -> ExportedTypeFunction {
1547+
// Calculate trivia
1548+
let shape = shape.reset();
1549+
let leading_trivia = vec![create_indent_trivia(ctx, shape)];
1550+
1551+
let export_token = format_symbol(
1552+
ctx,
1553+
exported_type_function.export_token(),
1554+
&TokenReference::new(
1555+
vec![],
1556+
Token::new(TokenType::Identifier {
1557+
identifier: "export".into(),
1558+
}),
1559+
vec![Token::new(TokenType::spaces(1))],
1560+
),
1561+
shape,
1562+
)
1563+
.update_leading_trivia(FormatTriviaType::Append(leading_trivia));
1564+
let type_function = format_type_function(
1565+
ctx,
1566+
exported_type_function.type_function(),
1567+
false,
1568+
shape + 7, // 7 = "export "
1569+
);
1570+
1571+
exported_type_function
1572+
.to_owned()
1573+
.with_export_token(export_token)
1574+
.with_type_function(type_function)
1575+
}

src/formatters/stmt.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
use crate::formatters::lua52::{format_goto, format_goto_no_trivia, format_label};
33
#[cfg(feature = "luau")]
44
use crate::formatters::luau::{
5-
format_compound_assignment, format_exported_type_declaration, format_type_declaration_stmt,
6-
format_type_specifier,
5+
format_compound_assignment, format_exported_type_declaration, format_exported_type_function,
6+
format_type_declaration_stmt, format_type_function_stmt, format_type_specifier,
77
};
88
use crate::{
99
context::{create_indent_trivia, create_newline_trivia, Context, FormatNode},
@@ -793,6 +793,8 @@ pub fn format_function_call_stmt(
793793
/// These are used for range formatting
794794
pub(crate) mod stmt_block {
795795
use crate::{context::Context, formatters::block::format_block, shape::Shape};
796+
#[cfg(feature = "luau")]
797+
use full_moon::ast::luau::TypeFunction;
796798
use full_moon::ast::{
797799
Call, Expression, Field, FunctionArgs, FunctionCall, Index, Prefix, Stmt, Suffix,
798800
TableConstructor,
@@ -907,6 +909,17 @@ pub(crate) mod stmt_block {
907909
.with_suffixes(suffixes)
908910
}
909911

912+
#[cfg(feature = "luau")]
913+
fn format_type_function_block(
914+
ctx: &Context,
915+
type_function: &TypeFunction,
916+
shape: Shape,
917+
) -> TypeFunction {
918+
let block = format_block(ctx, type_function.function_body().block(), shape);
919+
let body = type_function.function_body().to_owned().with_block(block);
920+
type_function.to_owned().with_function_body(body)
921+
}
922+
910923
/// Only formats a block within an expression
911924
pub fn format_expression_block(
912925
ctx: &Context,
@@ -1057,6 +1070,23 @@ pub(crate) mod stmt_block {
10571070
Stmt::ExportedTypeDeclaration(node) => Stmt::ExportedTypeDeclaration(node.to_owned()),
10581071
#[cfg(feature = "luau")]
10591072
Stmt::TypeDeclaration(node) => Stmt::TypeDeclaration(node.to_owned()),
1073+
#[cfg(feature = "luau")]
1074+
Stmt::ExportedTypeFunction(exported_type_function) => {
1075+
let type_function = format_type_function_block(
1076+
ctx,
1077+
exported_type_function.type_function(),
1078+
block_shape,
1079+
);
1080+
Stmt::ExportedTypeFunction(
1081+
exported_type_function
1082+
.to_owned()
1083+
.with_type_function(type_function),
1084+
)
1085+
}
1086+
#[cfg(feature = "luau")]
1087+
Stmt::TypeFunction(type_function) => {
1088+
Stmt::TypeFunction(format_type_function_block(ctx, type_function, block_shape))
1089+
}
10601090
#[cfg(feature = "lua52")]
10611091
Stmt::Goto(node) => Stmt::Goto(node.to_owned()),
10621092
#[cfg(feature = "lua52")]
@@ -1090,6 +1120,8 @@ pub fn format_stmt(ctx: &Context, stmt: &Stmt, shape: Shape) -> Stmt {
10901120
#[cfg(feature = "luau")] CompoundAssignment = format_compound_assignment,
10911121
#[cfg(feature = "luau")] ExportedTypeDeclaration = format_exported_type_declaration,
10921122
#[cfg(feature = "luau")] TypeDeclaration = format_type_declaration_stmt,
1123+
#[cfg(feature = "luau")] ExportedTypeFunction = format_exported_type_function,
1124+
#[cfg(feature = "luau")] TypeFunction = format_type_function_stmt,
10931125
#[cfg(feature = "lua52")] Goto = format_goto,
10941126
#[cfg(feature = "lua52")] Label = format_label,
10951127
})

src/formatters/trivia.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use full_moon::ast::lua54::Attribute;
44
use full_moon::ast::luau::{
55
ElseIfExpression, GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo,
66
IfExpression, IndexedTypeInfo, InterpolatedString, InterpolatedStringSegment, TypeArgument,
7-
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection,
8-
TypeSpecifier, TypeUnion,
7+
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction, TypeInfo,
8+
TypeIntersection, TypeSpecifier, TypeUnion,
99
};
1010
use full_moon::ast::{
1111
punctuated::Punctuated, span::ContainedSpan, BinOp, Call, Expression, FunctionArgs,
@@ -680,6 +680,18 @@ define_update_trivia!(Stmt, |this, leading, trailing| {
680680
}
681681
#[cfg(feature = "luau")]
682682
Stmt::TypeDeclaration(stmt) => Stmt::TypeDeclaration(stmt.update_trivia(leading, trailing)),
683+
#[cfg(feature = "luau")]
684+
Stmt::ExportedTypeFunction(stmt) => {
685+
let export_token = stmt.export_token().update_leading_trivia(leading);
686+
let type_function = stmt.type_function().update_trailing_trivia(trailing);
687+
Stmt::ExportedTypeFunction(
688+
stmt.to_owned()
689+
.with_export_token(export_token)
690+
.with_type_function(type_function),
691+
)
692+
}
693+
#[cfg(feature = "luau")]
694+
Stmt::TypeFunction(stmt) => Stmt::TypeFunction(stmt.update_trivia(leading, trailing)),
683695
#[cfg(feature = "lua52")]
684696
Stmt::Goto(stmt) => Stmt::Goto(
685697
stmt.to_owned()
@@ -920,6 +932,13 @@ define_update_trivia!(TypeDeclaration, |this, leading, trailing| {
920932
.with_type_definition(this.type_definition().update_trailing_trivia(trailing))
921933
});
922934

935+
#[cfg(feature = "luau")]
936+
define_update_trivia!(TypeFunction, |this, leading, trailing| {
937+
this.to_owned()
938+
.with_type_token(this.type_token().update_leading_trivia(leading))
939+
.with_function_body(this.function_body().update_trailing_trivia(trailing))
940+
});
941+
923942
#[cfg(feature = "luau")]
924943
define_update_trailing_trivia!(IndexedTypeInfo, |this, trailing| {
925944
match this {

src/formatters/trivia_util.rs

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ use crate::{
66
#[cfg(feature = "luau")]
77
use full_moon::ast::luau::{
88
GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument,
9-
TypeDeclaration, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
9+
TypeDeclaration, TypeFunction, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
1010
};
1111
use full_moon::{
1212
ast::{
1313
punctuated::{Pair, Punctuated},
14-
BinOp, Block, Call, Expression, Field, FunctionArgs, Index, LastStmt, LocalAssignment,
15-
Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var, VarExpression,
14+
BinOp, Block, Call, Expression, Field, FunctionArgs, FunctionBody, Index, LastStmt,
15+
LocalAssignment, Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var,
16+
VarExpression,
1617
},
1718
node::Node,
1819
tokenizer::{Token, TokenKind, TokenReference, TokenType},
@@ -263,6 +264,12 @@ impl GetTrailingTrivia for FunctionArgs {
263264
}
264265
}
265266

267+
impl GetTrailingTrivia for FunctionBody {
268+
fn trailing_trivia(&self) -> Vec<Token> {
269+
GetTrailingTrivia::trailing_trivia(self.end_token())
270+
}
271+
}
272+
266273
impl GetTrailingTrivia for Prefix {
267274
fn trailing_trivia(&self) -> Vec<Token> {
268275
match self {
@@ -522,7 +529,6 @@ pub fn take_leading_comments<T: GetLeadingTrivia + UpdateLeadingTrivia>(
522529
)
523530
}
524531

525-
#[cfg(feature = "luau")]
526532
pub fn take_trailing_trivia<T: GetTrailingTrivia + UpdateTrailingTrivia>(
527533
node: &T,
528534
) -> (T, Vec<Token>) {
@@ -714,6 +720,13 @@ impl GetTrailingTrivia for TypeDeclaration {
714720
}
715721
}
716722

723+
#[cfg(feature = "luau")]
724+
impl GetTrailingTrivia for TypeFunction {
725+
fn trailing_trivia(&self) -> Vec<Token> {
726+
self.function_body().trailing_trivia()
727+
}
728+
}
729+
717730
#[cfg(feature = "luau")]
718731
impl GetTrailingTrivia for TypeSpecifier {
719732
fn trailing_trivia(&self) -> Vec<Token> {
@@ -873,22 +886,14 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec<Token>) {
873886
end_stmt_trailing_trivia!(If, stmt)
874887
}
875888
Stmt::FunctionDeclaration(stmt) => {
876-
let end_token = stmt.body().end_token();
877-
let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect();
878-
let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![]));
879-
880-
let body = stmt.body().to_owned().with_end_token(new_end_token);
889+
let (body, trailing_trivia) = take_trailing_trivia(stmt.body());
881890
(
882891
Stmt::FunctionDeclaration(stmt.with_body(body)),
883892
trailing_trivia,
884893
)
885894
}
886895
Stmt::LocalFunction(stmt) => {
887-
let end_token = stmt.body().end_token();
888-
let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect();
889-
let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![]));
890-
891-
let body = stmt.body().to_owned().with_end_token(new_end_token);
896+
let (body, trailing_trivia) = take_trailing_trivia(stmt.body());
892897
(Stmt::LocalFunction(stmt.with_body(body)), trailing_trivia)
893898
}
894899
Stmt::NumericFor(stmt) => {
@@ -922,6 +927,19 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec<Token>) {
922927
let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt);
923928
(Stmt::TypeDeclaration(type_declaration), trailing_trivia)
924929
}
930+
#[cfg(feature = "luau")]
931+
Stmt::ExportedTypeFunction(stmt) => {
932+
let (type_function, trailing_trivia) = take_trailing_trivia(stmt.type_function());
933+
(
934+
Stmt::ExportedTypeFunction(stmt.with_type_function(type_function)),
935+
trailing_trivia,
936+
)
937+
}
938+
#[cfg(feature = "luau")]
939+
Stmt::TypeFunction(stmt) => {
940+
let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt);
941+
(Stmt::TypeFunction(type_declaration), trailing_trivia)
942+
}
925943
#[cfg(feature = "lua52")]
926944
Stmt::Goto(stmt) => {
927945
let trailing_trivia = stmt
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
type function Foo(x)
2+
end
3+
4+
export type function Foo(x)
5+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: tests/tests.rs
3+
expression: "format(&contents, LuaVersion::Luau)"
4+
input_file: tests/inputs-luau/type-functions-1.lua
5+
snapshot_kind: text
6+
---
7+
type function Foo(x) end
8+
9+
export type function Foo(x) end

0 commit comments

Comments
 (0)