@@ -8450,6 +8450,10 @@ pub struct ExtractFunction<'a> {
8450
8450
edits : TextEdits < ' a > ,
8451
8451
function : Option < ExtractedFunction < ' a > > ,
8452
8452
function_end_position : Option < u32 > ,
8453
+ /// Since the `visit_typed_statement` visitor function doesn't tell us when
8454
+ /// a statement is the last in a block or function, we need to track that
8455
+ /// manually.
8456
+ last_statement_location : Option < SrcSpan > ,
8453
8457
}
8454
8458
8455
8459
/// Information about a section of code we are extracting as a function.
@@ -8479,15 +8483,65 @@ impl<'a> ExtractedFunction<'a> {
8479
8483
fn location ( & self ) -> SrcSpan {
8480
8484
match & self . value {
8481
8485
ExtractedValue :: Expression ( expression) => expression. location ( ) ,
8482
- ExtractedValue :: Statements ( location) => * location,
8486
+ ExtractedValue :: Statements { location, .. } => * location,
8483
8487
}
8484
8488
}
8485
8489
}
8486
8490
8487
8491
#[ derive( Debug ) ]
8488
8492
enum ExtractedValue < ' a > {
8489
8493
Expression ( & ' a TypedExpr ) ,
8490
- Statements ( SrcSpan ) ,
8494
+ Statements {
8495
+ location : SrcSpan ,
8496
+ position : StatementPosition ,
8497
+ } ,
8498
+ }
8499
+
8500
+ /// When we are extracting multiple statements, there are two possible cases:
8501
+ /// The first is if we are extracting statements in the middle of a function.
8502
+ /// In this case, we will need to return some number of variables, or `Nil`.
8503
+ /// For example:
8504
+ ///
8505
+ /// ```gleam
8506
+ /// pub fn main() {
8507
+ /// let message = "Hello!"
8508
+ /// let log_message = "[INFO] " <> message
8509
+ /// //^ Select from here
8510
+ /// io.println(log_message)
8511
+ /// // ^ Until here
8512
+ ///
8513
+ /// do_some_more_things()
8514
+ /// }
8515
+ /// ```
8516
+ ///
8517
+ /// Here, the extracted function doesn't bind any variables which we need
8518
+ /// afterwards, it purely performs side effects. In this case we can just return
8519
+ /// `Nil` from the new function.
8520
+ ///
8521
+ /// However, consider the following:
8522
+ ///
8523
+ /// ```gleam
8524
+ /// pub fn main() {
8525
+ /// let a = 1
8526
+ /// let b = 2
8527
+ /// //^ Select from here
8528
+ /// a + b
8529
+ /// // ^ Until here
8530
+ /// }
8531
+ /// ```
8532
+ ///
8533
+ /// Here, despite us not needing any variables from the extracted code, there
8534
+ /// is one key difference: the `a + b` expression is at the end of the function,
8535
+ /// and so its value is returned from the entire function. This is known as the
8536
+ /// "tail" position. In that case, we can't return `Nil` as that would make the
8537
+ /// `main` function return `Nil` instead of the result of the addition. If we
8538
+ /// extract the tail-position statement, we need to return that last value rather
8539
+ /// than `Nil`.
8540
+ ///
8541
+ #[ derive( Debug ) ]
8542
+ enum StatementPosition {
8543
+ Tail { type_ : Arc < Type > } ,
8544
+ NotTail ,
8491
8545
}
8492
8546
8493
8547
impl < ' a > ExtractFunction < ' a > {
@@ -8502,6 +8556,7 @@ impl<'a> ExtractFunction<'a> {
8502
8556
edits : TextEdits :: new ( line_numbers) ,
8503
8557
function : None ,
8504
8558
function_end_position : None ,
8559
+ last_statement_location : None ,
8505
8560
}
8506
8561
}
8507
8562
@@ -8524,15 +8579,88 @@ impl<'a> ExtractFunction<'a> {
8524
8579
} ;
8525
8580
8526
8581
match extracted. value {
8527
- ExtractedValue :: Expression ( expression) => {
8528
- self . extract_expression ( expression, extracted. parameters , end)
8582
+ // If we extract a block, it isn't very helpful to have the body of the
8583
+ // extracted function just be a single block expression, so instead we
8584
+ // extract the statements inside the block. For example, the following
8585
+ // code:
8586
+ //
8587
+ // ```gleam
8588
+ // pub fn main() {
8589
+ // let x = {
8590
+ // // ^ Select from here
8591
+ // let a = 1
8592
+ // let b = 2
8593
+ // a + b
8594
+ // }
8595
+ // //^ Until here
8596
+ // x
8597
+ // }
8598
+ // ```
8599
+ //
8600
+ // Would produce the following extracted function:
8601
+ //
8602
+ // ```gleam
8603
+ // fn function() {
8604
+ // let a = 1
8605
+ // let b = 2
8606
+ // a + b
8607
+ // }
8608
+ // ```
8609
+ //
8610
+ // Rather than:
8611
+ //
8612
+ // ```gleam
8613
+ // fn function() {
8614
+ // {
8615
+ // let a = 1
8616
+ // let b = 2
8617
+ // a + b
8618
+ // }
8619
+ // }
8620
+ // ```
8621
+ //
8622
+ ExtractedValue :: Expression ( TypedExpr :: Block {
8623
+ statements,
8624
+ location : full_location,
8625
+ } ) => {
8626
+ let location = SrcSpan :: new (
8627
+ statements. first ( ) . location ( ) . start ,
8628
+ statements. last ( ) . location ( ) . end ,
8629
+ ) ;
8630
+ self . extract_code_in_tail_position (
8631
+ * full_location,
8632
+ location,
8633
+ statements. last ( ) . type_ ( ) ,
8634
+ extracted. parameters ,
8635
+ end,
8636
+ )
8529
8637
}
8530
- ExtractedValue :: Statements ( location) => self . extract_statements (
8638
+ ExtractedValue :: Expression ( expression) => self . extract_code_in_tail_position (
8639
+ expression. location ( ) ,
8640
+ expression. location ( ) ,
8641
+ expression. type_ ( ) ,
8642
+ extracted. parameters ,
8643
+ end,
8644
+ ) ,
8645
+ ExtractedValue :: Statements {
8646
+ location,
8647
+ position : StatementPosition :: NotTail ,
8648
+ } => self . extract_statements (
8531
8649
location,
8532
8650
extracted. parameters ,
8533
8651
extracted. returned_variables ,
8534
8652
end,
8535
8653
) ,
8654
+ ExtractedValue :: Statements {
8655
+ location,
8656
+ position : StatementPosition :: Tail { type_ } ,
8657
+ } => self . extract_code_in_tail_position (
8658
+ location,
8659
+ location,
8660
+ type_,
8661
+ extracted. parameters ,
8662
+ end,
8663
+ ) ,
8536
8664
}
8537
8665
8538
8666
let mut action = Vec :: with_capacity ( 1 ) ;
@@ -8561,62 +8689,17 @@ impl<'a> ExtractFunction<'a> {
8561
8689
}
8562
8690
}
8563
8691
8564
- fn extract_expression (
8692
+ /// Extracts code from the end of a function or block. This could either be
8693
+ /// a single expression, or multiple statements followed by a final expression.
8694
+ fn extract_code_in_tail_position (
8565
8695
& mut self ,
8566
- expression : & TypedExpr ,
8696
+ location : SrcSpan ,
8697
+ code_location : SrcSpan ,
8698
+ type_ : Arc < Type > ,
8567
8699
parameters : Vec < ( EcoString , Arc < Type > ) > ,
8568
8700
function_end : u32 ,
8569
8701
) {
8570
- // If we extract a block, it isn't very helpful to have the body of the
8571
- // extracted function just be a single block expression, so instead we
8572
- // extract the statements inside the block. For example, the following
8573
- // code:
8574
- //
8575
- // ```gleam
8576
- // pub fn main() {
8577
- // let x = {
8578
- // // ^ Select from here
8579
- // let a = 1
8580
- // let b = 2
8581
- // a + b
8582
- // }
8583
- // //^ Until here
8584
- // x
8585
- // }
8586
- // ```
8587
- //
8588
- // Would produce the following extracted function:
8589
- //
8590
- // ```gleam
8591
- // fn function() {
8592
- // let a = 1
8593
- // let b = 2
8594
- // a + b
8595
- // }
8596
- // ```
8597
- //
8598
- // Rather than:
8599
- //
8600
- // ```gleam
8601
- // fn function() {
8602
- // {
8603
- // let a = 1
8604
- // let b = 2
8605
- // a + b
8606
- // }
8607
- // }
8608
- // ```
8609
- //
8610
- let extracted_code_location = if let TypedExpr :: Block { statements, .. } = expression {
8611
- SrcSpan :: new (
8612
- statements. first ( ) . location ( ) . start ,
8613
- statements. last ( ) . location ( ) . end ,
8614
- )
8615
- } else {
8616
- expression. location ( )
8617
- } ;
8618
-
8619
- let expression_code = code_at ( self . module , extracted_code_location) ;
8702
+ let expression_code = code_at ( self . module , code_location) ;
8620
8703
8621
8704
let name = self . function_name ( ) ;
8622
8705
let arguments = parameters. iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
@@ -8626,15 +8709,15 @@ impl<'a> ExtractFunction<'a> {
8626
8709
// it with the call and preserve all other semantics; only one value can
8627
8710
// be returned from the expression, unlike when extracting multiple
8628
8711
// statements.
8629
- self . edits . replace ( expression . location ( ) , call) ;
8712
+ self . edits . replace ( location, call) ;
8630
8713
8631
8714
let mut printer = Printer :: new ( & self . module . ast . names ) ;
8632
8715
8633
8716
let parameters = parameters
8634
8717
. iter ( )
8635
8718
. map ( |( name, type_) | eco_format ! ( "{name}: {}" , printer. print_type( type_) ) )
8636
8719
. join ( ", " ) ;
8637
- let return_type = printer. print_type ( & expression . type_ ( ) ) ;
8720
+ let return_type = printer. print_type ( & type_) ;
8638
8721
8639
8722
let function = format ! (
8640
8723
"\n \n fn {name}({parameters}) -> {return_type} {{
@@ -8861,11 +8944,25 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8861
8944
8862
8945
if within ( self . params . range , range) {
8863
8946
self . function_end_position = Some ( function. end_position ) ;
8947
+ self . last_statement_location = function. body . last ( ) . map ( |last| last. location ( ) ) ;
8864
8948
8865
8949
ast:: visit:: visit_typed_function ( self , function) ;
8866
8950
}
8867
8951
}
8868
8952
8953
+ fn visit_typed_expr_block (
8954
+ & mut self ,
8955
+ location : & ' ast SrcSpan ,
8956
+ statements : & ' ast [ TypedStatement ] ,
8957
+ ) {
8958
+ let last_statement_location = self . last_statement_location ;
8959
+ self . last_statement_location = statements. last ( ) . map ( |last| last. location ( ) ) ;
8960
+
8961
+ ast:: visit:: visit_typed_expr_block ( self , location, statements) ;
8962
+
8963
+ self . last_statement_location = last_statement_location;
8964
+ }
8965
+
8869
8966
fn visit_typed_expr ( & mut self , expression : & ' ast TypedExpr ) {
8870
8967
// If we have already determined what code we want to extract, we don't
8871
8968
// want to extract this instead. This expression would be inside the
@@ -8884,12 +8981,24 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8884
8981
}
8885
8982
8886
8983
fn visit_typed_statement ( & mut self , statement : & ' ast TypedStatement ) {
8887
- if self . can_extract ( statement. location ( ) ) {
8984
+ let location = statement. location ( ) ;
8985
+ if self . can_extract ( location) {
8986
+ let position = if let Some ( last_statement_location) = self . last_statement_location
8987
+ && location == last_statement_location
8988
+ {
8989
+ StatementPosition :: Tail {
8990
+ type_ : statement. type_ ( ) ,
8991
+ }
8992
+ } else {
8993
+ StatementPosition :: NotTail
8994
+ } ;
8995
+
8888
8996
match & mut self . function {
8889
8997
None => {
8890
- self . function = Some ( ExtractedFunction :: new ( ExtractedValue :: Statements (
8891
- statement. location ( ) ,
8892
- ) ) ) ;
8998
+ self . function = Some ( ExtractedFunction :: new ( ExtractedValue :: Statements {
8999
+ location,
9000
+ position,
9001
+ } ) ) ;
8893
9002
}
8894
9003
// If we have already chosen an expression to extract, that means
8895
9004
// that this statement is within the already extracted expression,
@@ -8902,9 +9011,16 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8902
9011
// be included within list, so we merge the spans to ensure it is
8903
9012
// included.
8904
9013
Some ( ExtractedFunction {
8905
- value : ExtractedValue :: Statements ( location) ,
9014
+ value :
9015
+ ExtractedValue :: Statements {
9016
+ location,
9017
+ position : extracted_position,
9018
+ } ,
8906
9019
..
8907
- } ) => * location = location. merge ( & statement. location ( ) ) ,
9020
+ } ) => {
9021
+ * location = location. merge ( & statement. location ( ) ) ;
9022
+ * extracted_position = position;
9023
+ }
8908
9024
}
8909
9025
}
8910
9026
ast:: visit:: visit_typed_statement ( self , statement) ;
0 commit comments