Skip to content

Commit 9ef66cf

Browse files
GearsDatapacks authored and lpil committed
Fix extracting tail-position statements
1 parent dc632c2 commit 9ef66cf

File tree

4 files changed

+230
-67
lines changed

4 files changed

+230
-67
lines changed

compiler-core/src/language_server/code_action.rs

Lines changed: 181 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8450,6 +8450,10 @@ pub struct ExtractFunction<'a> {
84508450
edits: TextEdits<'a>,
84518451
function: Option<ExtractedFunction<'a>>,
84528452
function_end_position: Option<u32>,
8453+
/// Since the `visit_typed_statement` visitor function doesn't tell us when
8454+
/// a statement is the last in a block or function, we need to track that
8455+
/// manually.
8456+
last_statement_location: Option<SrcSpan>,
84538457
}
84548458

84558459
/// Information about a section of code we are extracting as a function.
@@ -8479,15 +8483,65 @@ impl<'a> ExtractedFunction<'a> {
84798483
fn location(&self) -> SrcSpan {
84808484
match &self.value {
84818485
ExtractedValue::Expression(expression) => expression.location(),
8482-
ExtractedValue::Statements(location) => *location,
8486+
ExtractedValue::Statements { location, .. } => *location,
84838487
}
84848488
}
84858489
}
84868490

84878491
#[derive(Debug)]
84888492
enum ExtractedValue<'a> {
84898493
Expression(&'a TypedExpr),
8490-
Statements(SrcSpan),
8494+
Statements {
8495+
location: SrcSpan,
8496+
position: StatementPosition,
8497+
},
8498+
}
8499+
8500+
/// When we are extracting multiple statements, there are two possible cases:
8501+
/// The first is if we are extracting statements in the middle of a function.
8502+
/// In this case, we will need to return some number of variables, or `Nil`.
8503+
/// For example:
8504+
///
8505+
/// ```gleam
8506+
/// pub fn main() {
8507+
/// let message = "Hello!"
8508+
/// let log_message = "[INFO] " <> message
8509+
/// //^ Select from here
8510+
/// io.println(log_message)
8511+
/// // ^ Until here
8512+
///
8513+
/// do_some_more_things()
8514+
/// }
8515+
/// ```
8516+
///
8517+
/// Here, the extracted function doesn't bind any variables which we need
8518+
/// afterwards, it purely performs side effects. In this case we can just return
8519+
/// `Nil` from the new function.
8520+
///
8521+
/// However, consider the following:
8522+
///
8523+
/// ```gleam
8524+
/// pub fn main() {
8525+
/// let a = 1
8526+
/// let b = 2
8527+
/// //^ Select from here
8528+
/// a + b
8529+
/// // ^ Until here
8530+
/// }
8531+
/// ```
8532+
///
8533+
/// Here, despite us not needing any variables from the extracted code, there
8534+
/// is one key difference: the `a + b` expression is at the end of the function,
8535+
/// and so its value is returned from the entire function. This is known as the
8536+
/// "tail" position. In that case, we can't return `Nil` as that would make the
8537+
/// `main` function return `Nil` instead of the result of the addition. If we
8538+
/// extract the tail-position statement, we need to return that last value rather
8539+
/// than `Nil`.
8540+
///
8541+
#[derive(Debug)]
8542+
enum StatementPosition {
8543+
Tail { type_: Arc<Type> },
8544+
NotTail,
84918545
}
84928546

84938547
impl<'a> ExtractFunction<'a> {
@@ -8502,6 +8556,7 @@ impl<'a> ExtractFunction<'a> {
85028556
edits: TextEdits::new(line_numbers),
85038557
function: None,
85048558
function_end_position: None,
8559+
last_statement_location: None,
85058560
}
85068561
}
85078562

@@ -8524,15 +8579,88 @@ impl<'a> ExtractFunction<'a> {
85248579
};
85258580

85268581
match extracted.value {
8527-
ExtractedValue::Expression(expression) => {
8528-
self.extract_expression(expression, extracted.parameters, end)
8582+
// If we extract a block, it isn't very helpful to have the body of the
8583+
// extracted function just be a single block expression, so instead we
8584+
// extract the statements inside the block. For example, the following
8585+
// code:
8586+
//
8587+
// ```gleam
8588+
// pub fn main() {
8589+
// let x = {
8590+
// // ^ Select from here
8591+
// let a = 1
8592+
// let b = 2
8593+
// a + b
8594+
// }
8595+
// //^ Until here
8596+
// x
8597+
// }
8598+
// ```
8599+
//
8600+
// Would produce the following extracted function:
8601+
//
8602+
// ```gleam
8603+
// fn function() {
8604+
// let a = 1
8605+
// let b = 2
8606+
// a + b
8607+
// }
8608+
// ```
8609+
//
8610+
// Rather than:
8611+
//
8612+
// ```gleam
8613+
// fn function() {
8614+
// {
8615+
// let a = 1
8616+
// let b = 2
8617+
// a + b
8618+
// }
8619+
// }
8620+
// ```
8621+
//
8622+
ExtractedValue::Expression(TypedExpr::Block {
8623+
statements,
8624+
location: full_location,
8625+
}) => {
8626+
let location = SrcSpan::new(
8627+
statements.first().location().start,
8628+
statements.last().location().end,
8629+
);
8630+
self.extract_code_in_tail_position(
8631+
*full_location,
8632+
location,
8633+
statements.last().type_(),
8634+
extracted.parameters,
8635+
end,
8636+
)
85298637
}
8530-
ExtractedValue::Statements(location) => self.extract_statements(
8638+
ExtractedValue::Expression(expression) => self.extract_code_in_tail_position(
8639+
expression.location(),
8640+
expression.location(),
8641+
expression.type_(),
8642+
extracted.parameters,
8643+
end,
8644+
),
8645+
ExtractedValue::Statements {
8646+
location,
8647+
position: StatementPosition::NotTail,
8648+
} => self.extract_statements(
85318649
location,
85328650
extracted.parameters,
85338651
extracted.returned_variables,
85348652
end,
85358653
),
8654+
ExtractedValue::Statements {
8655+
location,
8656+
position: StatementPosition::Tail { type_ },
8657+
} => self.extract_code_in_tail_position(
8658+
location,
8659+
location,
8660+
type_,
8661+
extracted.parameters,
8662+
end,
8663+
),
85368664
}
85378665

85388666
let mut action = Vec::with_capacity(1);
@@ -8561,62 +8689,17 @@ impl<'a> ExtractFunction<'a> {
85618689
}
85628690
}
85638691

8564-
fn extract_expression(
8692+
/// Extracts code from the end of a function or block. This could either be
8693+
/// a single expression, or multiple statements followed by a final expression.
8694+
fn extract_code_in_tail_position(
85658695
&mut self,
8566-
expression: &TypedExpr,
8696+
location: SrcSpan,
8697+
code_location: SrcSpan,
8698+
type_: Arc<Type>,
85678699
parameters: Vec<(EcoString, Arc<Type>)>,
85688700
function_end: u32,
85698701
) {
8570-
// If we extract a block, it isn't very helpful to have the body of the
8571-
// extracted function just be a single block expression, so instead we
8572-
// extract the statements inside the block. For example, the following
8573-
// code:
8574-
//
8575-
// ```gleam
8576-
// pub fn main() {
8577-
// let x = {
8578-
// // ^ Select from here
8579-
// let a = 1
8580-
// let b = 2
8581-
// a + b
8582-
// }
8583-
// //^ Until here
8584-
// x
8585-
// }
8586-
// ```
8587-
//
8588-
// Would produce the following extracted function:
8589-
//
8590-
// ```gleam
8591-
// fn function() {
8592-
// let a = 1
8593-
// let b = 2
8594-
// a + b
8595-
// }
8596-
// ```
8597-
//
8598-
// Rather than:
8599-
//
8600-
// ```gleam
8601-
// fn function() {
8602-
// {
8603-
// let a = 1
8604-
// let b = 2
8605-
// a + b
8606-
// }
8607-
// }
8608-
// ```
8609-
//
8610-
let extracted_code_location = if let TypedExpr::Block { statements, .. } = expression {
8611-
SrcSpan::new(
8612-
statements.first().location().start,
8613-
statements.last().location().end,
8614-
)
8615-
} else {
8616-
expression.location()
8617-
};
8618-
8619-
let expression_code = code_at(self.module, extracted_code_location);
8702+
let expression_code = code_at(self.module, code_location);
86208703

86218704
let name = self.function_name();
86228705
let arguments = parameters.iter().map(|(name, _)| name).join(", ");
@@ -8626,15 +8709,15 @@ impl<'a> ExtractFunction<'a> {
86268709
// it with the call and preserve all other semantics; only one value can
86278710
// be returned from the expression, unlike when extracting multiple
86288711
// statements.
8629-
self.edits.replace(expression.location(), call);
8712+
self.edits.replace(location, call);
86308713

86318714
let mut printer = Printer::new(&self.module.ast.names);
86328715

86338716
let parameters = parameters
86348717
.iter()
86358718
.map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_)))
86368719
.join(", ");
8637-
let return_type = printer.print_type(&expression.type_());
8720+
let return_type = printer.print_type(&type_);
86388721

86398722
let function = format!(
86408723
"\n\nfn {name}({parameters}) -> {return_type} {{
@@ -8861,11 +8944,25 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
88618944

88628945
if within(self.params.range, range) {
88638946
self.function_end_position = Some(function.end_position);
8947+
self.last_statement_location = function.body.last().map(|last| last.location());
88648948

88658949
ast::visit::visit_typed_function(self, function);
88668950
}
88678951
}
88688952

8953+
fn visit_typed_expr_block(
8954+
&mut self,
8955+
location: &'ast SrcSpan,
8956+
statements: &'ast [TypedStatement],
8957+
) {
8958+
let last_statement_location = self.last_statement_location;
8959+
self.last_statement_location = statements.last().map(|last| last.location());
8960+
8961+
ast::visit::visit_typed_expr_block(self, location, statements);
8962+
8963+
self.last_statement_location = last_statement_location;
8964+
}
8965+
88698966
fn visit_typed_expr(&mut self, expression: &'ast TypedExpr) {
88708967
// If we have already determined what code we want to extract, we don't
88718968
// want to extract this instead. This expression would be inside the
@@ -8884,12 +8981,24 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
88848981
}
88858982

88868983
fn visit_typed_statement(&mut self, statement: &'ast TypedStatement) {
8887-
if self.can_extract(statement.location()) {
8984+
let location = statement.location();
8985+
if self.can_extract(location) {
8986+
let position = if let Some(last_statement_location) = self.last_statement_location
8987+
&& location == last_statement_location
8988+
{
8989+
StatementPosition::Tail {
8990+
type_: statement.type_(),
8991+
}
8992+
} else {
8993+
StatementPosition::NotTail
8994+
};
8995+
88888996
match &mut self.function {
88898997
None => {
8890-
self.function = Some(ExtractedFunction::new(ExtractedValue::Statements(
8891-
statement.location(),
8892-
)));
8998+
self.function = Some(ExtractedFunction::new(ExtractedValue::Statements {
8999+
location,
9000+
position,
9001+
}));
88939002
}
88949003
// If we have already chosen an expression to extract, that means
88959004
// that this statement is within the already extracted expression,
@@ -8902,9 +9011,16 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
89029011
be included within list, so we merge the spans to ensure it is
89039012
// included.
89049013
Some(ExtractedFunction {
8905-
value: ExtractedValue::Statements(location),
9014+
value:
9015+
ExtractedValue::Statements {
9016+
location,
9017+
position: extracted_position,
9018+
},
89069019
..
8907-
}) => *location = location.merge(&statement.location()),
9020+
}) => {
9021+
*location = location.merge(&statement.location());
9022+
*extracted_position = position;
9023+
}
89089024
}
89099025
}
89109026
ast::visit::visit_typed_statement(self, statement);

compiler-core/src/language_server/tests/action.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10537,3 +10537,20 @@ pub fn other() {
1053710537
find_position_of("let a").select_until(find_position_of("let b"))
1053810538
);
1053910539
}
10540+
10541+
#[test]
10542+
fn extract_statements_in_tail_position() {
10543+
assert_code_action!(
10544+
EXTRACT_FUNCTION,
10545+
r#"
10546+
pub fn main() {
10547+
let a = 1
10548+
let b = 2
10549+
let c = 3
10550+
let d = 4
10551+
a * b + c * d
10552+
}
10553+
"#,
10554+
find_position_of("let c").select_until(find_position_of("* d"))
10555+
);
10556+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "\npub fn main() {\n let a = 1\n let b = 2\n let c = 3\n let d = 4\n a * b + c * d\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
7+
pub fn main() {
8+
let a = 1
9+
let b = 2
10+
let c = 3
11+
▔▔▔▔▔▔▔▔▔
12+
let d = 4
13+
▔▔▔▔▔▔▔▔▔▔▔
14+
a * b + c * d
15+
▔▔▔▔▔▔▔▔▔▔▔▔↑
16+
}
17+
18+
19+
----- AFTER ACTION
20+
21+
pub fn main() {
22+
let a = 1
23+
let b = 2
24+
function(a, b)
25+
}
26+
27+
fn function(a: Int, b: Int) -> Int {
28+
let c = 3
29+
let d = 4
30+
a * b + c * d
31+
}

0 commit comments

Comments
 (0)