@@ -8417,15 +8417,40 @@ pub struct ExtractFunction<'a> {
8417
8417
module : & ' a Module ,
8418
8418
params : & ' a CodeActionParams ,
8419
8419
edits : TextEdits < ' a > ,
8420
- extract : Extract < ' a > ,
8420
+ extract : Option < ExtractedFunction < ' a > > ,
8421
8421
function_end_position : Option < u32 > ,
8422
8422
}
8423
8423
8424
+ struct ExtractedFunction < ' a > {
8425
+ parameters : Vec < ( EcoString , Arc < Type > ) > ,
8426
+ returned_variables : Vec < ( EcoString , Arc < Type > ) > ,
8427
+ value : ExtractedValue < ' a > ,
8428
+ }
8429
+
8430
+ impl < ' a > ExtractedFunction < ' a > {
8431
+ fn new ( value : ExtractedValue < ' a > ) -> Self {
8432
+ Self {
8433
+ value,
8434
+ parameters : Vec :: new ( ) ,
8435
+ returned_variables : Vec :: new ( ) ,
8436
+ }
8437
+ }
8438
+
8439
+ fn location ( & self ) -> SrcSpan {
8440
+ match & self . value {
8441
+ ExtractedValue :: Expression ( expression) => expression. location ( ) ,
8442
+ ExtractedValue :: Statements ( statements) => SrcSpan :: new (
8443
+ statements. first ( ) . location ( ) . start ,
8444
+ statements. last ( ) . location ( ) . end ,
8445
+ ) ,
8446
+ }
8447
+ }
8448
+ }
8449
+
8424
8450
#[ derive( Debug ) ]
8425
- enum Extract < ' a > {
8426
- None ,
8451
+ enum ExtractedValue < ' a > {
8427
8452
Expression ( & ' a TypedExpr ) ,
8428
- Statements ( Vec < & ' a TypedStatement > ) ,
8453
+ Statements ( Vec1 < & ' a TypedStatement > ) ,
8429
8454
}
8430
8455
8431
8456
impl < ' a > ExtractFunction < ' a > {
@@ -8438,7 +8463,7 @@ impl<'a> ExtractFunction<'a> {
8438
8463
module,
8439
8464
params,
8440
8465
edits : TextEdits :: new ( line_numbers) ,
8441
- extract : Extract :: None ,
8466
+ extract : None ,
8442
8467
function_end_position : None ,
8443
8468
}
8444
8469
}
@@ -8454,10 +8479,19 @@ impl<'a> ExtractFunction<'a> {
8454
8479
return Vec :: new ( ) ;
8455
8480
} ;
8456
8481
8457
- match std:: mem:: replace ( & mut self . extract , Extract :: None ) {
8458
- Extract :: None => return Vec :: new ( ) ,
8459
- Extract :: Expression ( expression) => self . extract_expression ( expression, end) ,
8460
- Extract :: Statements ( statements) => self . extract_statements ( statements, end) ,
8482
+ let Some ( extracted) = self . extract . take ( ) else {
8483
+ return Vec :: new ( ) ;
8484
+ } ;
8485
+ match extracted. value {
8486
+ ExtractedValue :: Expression ( expression) => {
8487
+ self . extract_expression ( expression, extracted. parameters , end)
8488
+ }
8489
+ ExtractedValue :: Statements ( statements) => self . extract_statements (
8490
+ statements,
8491
+ extracted. parameters ,
8492
+ extracted. returned_variables ,
8493
+ end,
8494
+ ) ,
8461
8495
}
8462
8496
8463
8497
let mut action = Vec :: with_capacity ( 1 ) ;
@@ -8469,18 +8503,21 @@ impl<'a> ExtractFunction<'a> {
8469
8503
action
8470
8504
}
8471
8505
8472
- fn extract_expression ( & mut self , expression : & TypedExpr , function_end : u32 ) {
8473
- let referenced_variables = referenced_variables ( expression) ;
8474
-
8506
+ fn extract_expression (
8507
+ & mut self ,
8508
+ expression : & TypedExpr ,
8509
+ parameters : Vec < ( EcoString , Arc < Type > ) > ,
8510
+ function_end : u32 ,
8511
+ ) {
8475
8512
let expression_code = code_at ( self . module , expression. location ( ) ) ;
8476
8513
8477
- let arguments = referenced_variables . iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
8514
+ let arguments = parameters . iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
8478
8515
let call = format ! ( "function({arguments})" ) ;
8479
8516
self . edits . replace ( expression. location ( ) , call) ;
8480
8517
8481
8518
let mut printer = Printer :: new ( & self . module . ast . names ) ;
8482
8519
8483
- let parameters = referenced_variables
8520
+ let parameters = parameters
8484
8521
. iter ( )
8485
8522
. map ( |( name, type_) | eco_format ! ( "{name}: {}" , printer. print_type( type_) ) )
8486
8523
. join ( ", " ) ;
@@ -8495,40 +8532,97 @@ impl<'a> ExtractFunction<'a> {
8495
8532
self . edits . insert ( function_end, function) ;
8496
8533
}
8497
8534
8498
- fn extract_statements ( & mut self , statements : Vec < & TypedStatement > , function_end : u32 ) {
8499
- let Some ( first) = statements. first ( ) else {
8500
- return ;
8501
- } ;
8502
- let Some ( last) = statements. last ( ) else {
8503
- return ;
8504
- } ;
8535
+ fn extract_statements (
8536
+ & mut self ,
8537
+ statements : Vec1 < & TypedStatement > ,
8538
+ parameters : Vec < ( EcoString , Arc < Type > ) > ,
8539
+ returned_variables : Vec < ( EcoString , Arc < Type > ) > ,
8540
+ function_end : u32 ,
8541
+ ) {
8542
+ let first = statements. first ( ) ;
8543
+ let last = statements. last ( ) ;
8505
8544
8506
8545
let location = SrcSpan :: new ( first. location ( ) . start , last. location ( ) . end ) ;
8507
8546
8508
- let referenced_variables = referenced_variables_for_statements ( & statements, location) ;
8509
-
8510
8547
let code = code_at ( self . module , location) ;
8511
8548
8512
- let arguments = referenced_variables. iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
8513
- let call = format ! ( "function({arguments})" ) ;
8549
+ let returns_anything = !returned_variables. is_empty ( ) ;
8550
+
8551
+ let ( return_type, return_value) = match returned_variables. as_slice ( ) {
8552
+ [ ] => ( type_:: nil ( ) , "Nil" . into ( ) ) ,
8553
+ [ ( name, type_) ] => ( type_. clone ( ) , name. clone ( ) ) ,
8554
+ _ => {
8555
+ let values = returned_variables. iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
8556
+ let type_ = type_:: tuple (
8557
+ returned_variables
8558
+ . into_iter ( )
8559
+ . map ( |( _, type_) | type_)
8560
+ . collect ( ) ,
8561
+ ) ;
8562
+
8563
+ ( type_, eco_format ! ( "#({values})" ) )
8564
+ }
8565
+ } ;
8566
+
8567
+ let arguments = parameters. iter ( ) . map ( |( name, _) | name) . join ( ", " ) ;
8568
+
8569
+ let call = if returns_anything {
8570
+ format ! ( "let {return_value} = function({arguments})" )
8571
+ } else {
8572
+ format ! ( "function({arguments})" )
8573
+ } ;
8514
8574
self . edits . replace ( location, call) ;
8515
8575
8516
8576
let mut printer = Printer :: new ( & self . module . ast . names ) ;
8517
8577
8518
- let parameters = referenced_variables
8578
+ let parameters = parameters
8519
8579
. iter ( )
8520
8580
. map ( |( name, type_) | eco_format ! ( "{name}: {}" , printer. print_type( type_) ) )
8521
8581
. join ( ", " ) ;
8522
- let return_type = printer. print_type ( & last. type_ ( ) ) ;
8582
+
8583
+ let return_type = printer. print_type ( & return_type) ;
8523
8584
8524
8585
let function = format ! (
8525
8586
"\n \n fn function({parameters}) -> {return_type} {{
8526
8587
{code}
8588
+ {return_value}
8527
8589
}}"
8528
8590
) ;
8529
8591
8530
8592
self . edits . insert ( function_end, function) ;
8531
8593
}
8594
+
8595
+ fn register_referenced_variable (
8596
+ & mut self ,
8597
+ name : & EcoString ,
8598
+ type_ : & Arc < Type > ,
8599
+ location : SrcSpan ,
8600
+ definition_location : SrcSpan ,
8601
+ ) {
8602
+ let Some ( extracted) = & mut self . extract else {
8603
+ return ;
8604
+ } ;
8605
+
8606
+ let extracted_location = extracted. location ( ) ;
8607
+
8608
+ let variables = if extracted_location. contains_span ( location)
8609
+ && !extracted_location. contains_span ( definition_location)
8610
+ {
8611
+ & mut extracted. parameters
8612
+ } else if extracted_location. contains_span ( definition_location)
8613
+ && !extracted_location. contains_span ( location)
8614
+ {
8615
+ & mut extracted. returned_variables
8616
+ } else {
8617
+ return ;
8618
+ } ;
8619
+
8620
+ if variables. iter ( ) . any ( |( variable, _) | variable == name) {
8621
+ return ;
8622
+ }
8623
+
8624
+ variables. push ( ( name. clone ( ) , type_. clone ( ) ) ) ;
8625
+ }
8532
8626
}
8533
8627
8534
8628
impl < ' ast > ast:: visit:: Visit < ' ast > for ExtractFunction < ' ast > {
@@ -8543,16 +8637,14 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8543
8637
}
8544
8638
8545
8639
fn visit_typed_expr ( & mut self , expression : & ' ast TypedExpr ) {
8546
- match & self . extract {
8547
- Extract :: None => {
8548
- let range = self . edits . src_span_to_lsp_range ( expression. location ( ) ) ;
8640
+ if self . extract . is_none ( ) {
8641
+ let range = self . edits . src_span_to_lsp_range ( expression. location ( ) ) ;
8549
8642
8550
- if within ( range, self . params . range ) {
8551
- self . extract = Extract :: Expression ( expression ) ;
8552
- return ;
8553
- }
8643
+ if within ( range, self . params . range ) {
8644
+ self . extract = Some ( ExtractedFunction :: new ( ExtractedValue :: Expression (
8645
+ expression ,
8646
+ ) ) ) ;
8554
8647
}
8555
- Extract :: Expression ( _) | Extract :: Statements ( _) => { }
8556
8648
}
8557
8649
ast:: visit:: visit_typed_expr ( self , expression) ;
8558
8650
}
@@ -8561,81 +8653,43 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8561
8653
let range = self . edits . src_span_to_lsp_range ( statement. location ( ) ) ;
8562
8654
if within ( range, self . params . range ) {
8563
8655
match & mut self . extract {
8564
- Extract :: None => {
8565
- self . extract = Extract :: Statements ( vec ! [ statement] ) ;
8656
+ None => {
8657
+ self . extract = Some ( ExtractedFunction :: new ( ExtractedValue :: Statements ( vec1 ! [
8658
+ statement,
8659
+ ] ) ) ) ;
8566
8660
}
8567
- Extract :: Expression ( expression ) => {
8568
- if expression . location ( ) . contains_span ( statement . location ( ) ) {
8569
- return ;
8570
- }
8571
-
8572
- self . extract = Extract :: Statements ( vec ! [ statement ] ) ;
8573
- }
8574
- Extract :: Statements ( statements ) => {
8661
+ Some ( ExtractedFunction {
8662
+ value : ExtractedValue :: Expression ( _ ) ,
8663
+ ..
8664
+ } ) => { }
8665
+ Some ( ExtractedFunction {
8666
+ value : ExtractedValue :: Statements ( statements ) ,
8667
+ ..
8668
+ } ) => {
8575
8669
statements. push ( statement) ;
8576
8670
}
8577
8671
}
8578
- } else {
8579
- ast:: visit:: visit_typed_statement ( self , statement) ;
8580
- }
8581
- }
8582
- }
8583
-
8584
- fn referenced_variables ( expression : & TypedExpr ) -> Vec < ( EcoString , Arc < Type > ) > {
8585
- let mut references = ReferencedVariables :: new ( expression. location ( ) ) ;
8586
- references. visit_typed_expr ( expression) ;
8587
- references. variables
8588
- }
8589
-
8590
- fn referenced_variables_for_statements (
8591
- statements : & [ & TypedStatement ] ,
8592
- location : SrcSpan ,
8593
- ) -> Vec < ( EcoString , Arc < Type > ) > {
8594
- let mut references = ReferencedVariables :: new ( location) ;
8595
- for statement in statements {
8596
- references. visit_typed_statement ( * statement) ;
8597
- }
8598
- references. variables
8599
- }
8600
-
8601
- struct ReferencedVariables {
8602
- variables : Vec < ( EcoString , Arc < Type > ) > ,
8603
- location : SrcSpan ,
8604
- }
8605
-
8606
- impl ReferencedVariables {
8607
- fn new ( location : SrcSpan ) -> Self {
8608
- Self {
8609
- variables : Vec :: new ( ) ,
8610
- location,
8611
- }
8612
- }
8613
-
8614
- fn register ( & mut self , name : & EcoString , type_ : & Arc < Type > , definition_location : SrcSpan ) {
8615
- if self . location . contains_span ( definition_location) {
8616
- return ;
8617
- }
8618
-
8619
- if !self
8620
- . variables
8621
- . iter ( )
8622
- . any ( |( variable_name, _) | variable_name == name)
8623
- {
8624
- self . variables . push ( ( name. clone ( ) , type_. clone ( ) ) )
8625
8672
}
8673
+ ast:: visit:: visit_typed_statement ( self , statement) ;
8626
8674
}
8627
- }
8628
8675
8629
- impl < ' ast > ast:: visit:: Visit < ' ast > for ReferencedVariables {
8630
8676
fn visit_typed_expr_var (
8631
8677
& mut self ,
8632
- _location : & ' ast SrcSpan ,
8678
+ location : & ' ast SrcSpan ,
8633
8679
constructor : & ' ast ValueConstructor ,
8634
8680
name : & ' ast EcoString ,
8635
8681
) {
8636
8682
match & constructor. variant {
8637
- type_:: ValueConstructorVariant :: LocalVariable { location, .. } => {
8638
- self . register ( name, & constructor. type_ , * location) ;
8683
+ type_:: ValueConstructorVariant :: LocalVariable {
8684
+ location : definition_location,
8685
+ ..
8686
+ } => {
8687
+ self . register_referenced_variable (
8688
+ name,
8689
+ & constructor. type_ ,
8690
+ * location,
8691
+ * definition_location,
8692
+ ) ;
8639
8693
}
8640
8694
type_:: ValueConstructorVariant :: ModuleConstant { .. }
8641
8695
| type_:: ValueConstructorVariant :: LocalConstant { .. }
0 commit comments