Skip to content

Commit f0ead7e

Browse files
GearsDatapackslpil
authored andcommitted
Detect which variables are used outside the extracted function
1 parent 2dd3748 commit f0ead7e

7 files changed

+241
-100
lines changed

compiler-core/src/language_server/code_action.rs

Lines changed: 151 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -8417,15 +8417,40 @@ pub struct ExtractFunction<'a> {
84178417
module: &'a Module,
84188418
params: &'a CodeActionParams,
84198419
edits: TextEdits<'a>,
8420-
extract: Extract<'a>,
8420+
extract: Option<ExtractedFunction<'a>>,
84218421
function_end_position: Option<u32>,
84228422
}
84238423

8424+
struct ExtractedFunction<'a> {
8425+
parameters: Vec<(EcoString, Arc<Type>)>,
8426+
returned_variables: Vec<(EcoString, Arc<Type>)>,
8427+
value: ExtractedValue<'a>,
8428+
}
8429+
8430+
impl<'a> ExtractedFunction<'a> {
8431+
fn new(value: ExtractedValue<'a>) -> Self {
8432+
Self {
8433+
value,
8434+
parameters: Vec::new(),
8435+
returned_variables: Vec::new(),
8436+
}
8437+
}
8438+
8439+
fn location(&self) -> SrcSpan {
8440+
match &self.value {
8441+
ExtractedValue::Expression(expression) => expression.location(),
8442+
ExtractedValue::Statements(statements) => SrcSpan::new(
8443+
statements.first().location().start,
8444+
statements.last().location().end,
8445+
),
8446+
}
8447+
}
8448+
}
8449+
84248450
#[derive(Debug)]
8425-
enum Extract<'a> {
8426-
None,
8451+
enum ExtractedValue<'a> {
84278452
Expression(&'a TypedExpr),
8428-
Statements(Vec<&'a TypedStatement>),
8453+
Statements(Vec1<&'a TypedStatement>),
84298454
}
84308455

84318456
impl<'a> ExtractFunction<'a> {
@@ -8438,7 +8463,7 @@ impl<'a> ExtractFunction<'a> {
84388463
module,
84398464
params,
84408465
edits: TextEdits::new(line_numbers),
8441-
extract: Extract::None,
8466+
extract: None,
84428467
function_end_position: None,
84438468
}
84448469
}
@@ -8454,10 +8479,19 @@ impl<'a> ExtractFunction<'a> {
84548479
return Vec::new();
84558480
};
84568481

8457-
match std::mem::replace(&mut self.extract, Extract::None) {
8458-
Extract::None => return Vec::new(),
8459-
Extract::Expression(expression) => self.extract_expression(expression, end),
8460-
Extract::Statements(statements) => self.extract_statements(statements, end),
8482+
let Some(extracted) = self.extract.take() else {
8483+
return Vec::new();
8484+
};
8485+
match extracted.value {
8486+
ExtractedValue::Expression(expression) => {
8487+
self.extract_expression(expression, extracted.parameters, end)
8488+
}
8489+
ExtractedValue::Statements(statements) => self.extract_statements(
8490+
statements,
8491+
extracted.parameters,
8492+
extracted.returned_variables,
8493+
end,
8494+
),
84618495
}
84628496

84638497
let mut action = Vec::with_capacity(1);
@@ -8469,18 +8503,21 @@ impl<'a> ExtractFunction<'a> {
84698503
action
84708504
}
84718505

8472-
fn extract_expression(&mut self, expression: &TypedExpr, function_end: u32) {
8473-
let referenced_variables = referenced_variables(expression);
8474-
8506+
fn extract_expression(
8507+
&mut self,
8508+
expression: &TypedExpr,
8509+
parameters: Vec<(EcoString, Arc<Type>)>,
8510+
function_end: u32,
8511+
) {
84758512
let expression_code = code_at(self.module, expression.location());
84768513

8477-
let arguments = referenced_variables.iter().map(|(name, _)| name).join(", ");
8514+
let arguments = parameters.iter().map(|(name, _)| name).join(", ");
84788515
let call = format!("function({arguments})");
84798516
self.edits.replace(expression.location(), call);
84808517

84818518
let mut printer = Printer::new(&self.module.ast.names);
84828519

8483-
let parameters = referenced_variables
8520+
let parameters = parameters
84848521
.iter()
84858522
.map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_)))
84868523
.join(", ");
@@ -8495,40 +8532,97 @@ impl<'a> ExtractFunction<'a> {
84958532
self.edits.insert(function_end, function);
84968533
}
84978534

8498-
fn extract_statements(&mut self, statements: Vec<&TypedStatement>, function_end: u32) {
8499-
let Some(first) = statements.first() else {
8500-
return;
8501-
};
8502-
let Some(last) = statements.last() else {
8503-
return;
8504-
};
8535+
fn extract_statements(
8536+
&mut self,
8537+
statements: Vec1<&TypedStatement>,
8538+
parameters: Vec<(EcoString, Arc<Type>)>,
8539+
returned_variables: Vec<(EcoString, Arc<Type>)>,
8540+
function_end: u32,
8541+
) {
8542+
let first = statements.first();
8543+
let last = statements.last();
85058544

85068545
let location = SrcSpan::new(first.location().start, last.location().end);
85078546

8508-
let referenced_variables = referenced_variables_for_statements(&statements, location);
8509-
85108547
let code = code_at(self.module, location);
85118548

8512-
let arguments = referenced_variables.iter().map(|(name, _)| name).join(", ");
8513-
let call = format!("function({arguments})");
8549+
let returns_anything = !returned_variables.is_empty();
8550+
8551+
let (return_type, return_value) = match returned_variables.as_slice() {
8552+
[] => (type_::nil(), "Nil".into()),
8553+
[(name, type_)] => (type_.clone(), name.clone()),
8554+
_ => {
8555+
let values = returned_variables.iter().map(|(name, _)| name).join(", ");
8556+
let type_ = type_::tuple(
8557+
returned_variables
8558+
.into_iter()
8559+
.map(|(_, type_)| type_)
8560+
.collect(),
8561+
);
8562+
8563+
(type_, eco_format!("#({values})"))
8564+
}
8565+
};
8566+
8567+
let arguments = parameters.iter().map(|(name, _)| name).join(", ");
8568+
8569+
let call = if returns_anything {
8570+
format!("let {return_value} = function({arguments})")
8571+
} else {
8572+
format!("function({arguments})")
8573+
};
85148574
self.edits.replace(location, call);
85158575

85168576
let mut printer = Printer::new(&self.module.ast.names);
85178577

8518-
let parameters = referenced_variables
8578+
let parameters = parameters
85198579
.iter()
85208580
.map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_)))
85218581
.join(", ");
8522-
let return_type = printer.print_type(&last.type_());
8582+
8583+
let return_type = printer.print_type(&return_type);
85238584

85248585
let function = format!(
85258586
"\n\nfn function({parameters}) -> {return_type} {{
85268587
{code}
8588+
{return_value}
85278589
}}"
85288590
);
85298591

85308592
self.edits.insert(function_end, function);
85318593
}
8594+
8595+
fn register_referenced_variable(
8596+
&mut self,
8597+
name: &EcoString,
8598+
type_: &Arc<Type>,
8599+
location: SrcSpan,
8600+
definition_location: SrcSpan,
8601+
) {
8602+
let Some(extracted) = &mut self.extract else {
8603+
return;
8604+
};
8605+
8606+
let extracted_location = extracted.location();
8607+
8608+
let variables = if extracted_location.contains_span(location)
8609+
&& !extracted_location.contains_span(definition_location)
8610+
{
8611+
&mut extracted.parameters
8612+
} else if extracted_location.contains_span(definition_location)
8613+
&& !extracted_location.contains_span(location)
8614+
{
8615+
&mut extracted.returned_variables
8616+
} else {
8617+
return;
8618+
};
8619+
8620+
if variables.iter().any(|(variable, _)| variable == name) {
8621+
return;
8622+
}
8623+
8624+
variables.push((name.clone(), type_.clone()));
8625+
}
85328626
}
85338627

85348628
impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
@@ -8543,16 +8637,14 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
85438637
}
85448638

85458639
fn visit_typed_expr(&mut self, expression: &'ast TypedExpr) {
8546-
match &self.extract {
8547-
Extract::None => {
8548-
let range = self.edits.src_span_to_lsp_range(expression.location());
8640+
if self.extract.is_none() {
8641+
let range = self.edits.src_span_to_lsp_range(expression.location());
85498642

8550-
if within(range, self.params.range) {
8551-
self.extract = Extract::Expression(expression);
8552-
return;
8553-
}
8643+
if within(range, self.params.range) {
8644+
self.extract = Some(ExtractedFunction::new(ExtractedValue::Expression(
8645+
expression,
8646+
)));
85548647
}
8555-
Extract::Expression(_) | Extract::Statements(_) => {}
85568648
}
85578649
ast::visit::visit_typed_expr(self, expression);
85588650
}
@@ -8561,81 +8653,43 @@ impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
85618653
let range = self.edits.src_span_to_lsp_range(statement.location());
85628654
if within(range, self.params.range) {
85638655
match &mut self.extract {
8564-
Extract::None => {
8565-
self.extract = Extract::Statements(vec![statement]);
8656+
None => {
8657+
self.extract = Some(ExtractedFunction::new(ExtractedValue::Statements(vec1![
8658+
statement,
8659+
])));
85668660
}
8567-
Extract::Expression(expression) => {
8568-
if expression.location().contains_span(statement.location()) {
8569-
return;
8570-
}
8571-
8572-
self.extract = Extract::Statements(vec![statement]);
8573-
}
8574-
Extract::Statements(statements) => {
8661+
Some(ExtractedFunction {
8662+
value: ExtractedValue::Expression(_),
8663+
..
8664+
}) => {}
8665+
Some(ExtractedFunction {
8666+
value: ExtractedValue::Statements(statements),
8667+
..
8668+
}) => {
85758669
statements.push(statement);
85768670
}
85778671
}
8578-
} else {
8579-
ast::visit::visit_typed_statement(self, statement);
8580-
}
8581-
}
8582-
}
8583-
8584-
fn referenced_variables(expression: &TypedExpr) -> Vec<(EcoString, Arc<Type>)> {
8585-
let mut references = ReferencedVariables::new(expression.location());
8586-
references.visit_typed_expr(expression);
8587-
references.variables
8588-
}
8589-
8590-
fn referenced_variables_for_statements(
8591-
statements: &[&TypedStatement],
8592-
location: SrcSpan,
8593-
) -> Vec<(EcoString, Arc<Type>)> {
8594-
let mut references = ReferencedVariables::new(location);
8595-
for statement in statements {
8596-
references.visit_typed_statement(*statement);
8597-
}
8598-
references.variables
8599-
}
8600-
8601-
struct ReferencedVariables {
8602-
variables: Vec<(EcoString, Arc<Type>)>,
8603-
location: SrcSpan,
8604-
}
8605-
8606-
impl ReferencedVariables {
8607-
fn new(location: SrcSpan) -> Self {
8608-
Self {
8609-
variables: Vec::new(),
8610-
location,
8611-
}
8612-
}
8613-
8614-
fn register(&mut self, name: &EcoString, type_: &Arc<Type>, definition_location: SrcSpan) {
8615-
if self.location.contains_span(definition_location) {
8616-
return;
8617-
}
8618-
8619-
if !self
8620-
.variables
8621-
.iter()
8622-
.any(|(variable_name, _)| variable_name == name)
8623-
{
8624-
self.variables.push((name.clone(), type_.clone()))
86258672
}
8673+
ast::visit::visit_typed_statement(self, statement);
86268674
}
8627-
}
86288675

8629-
impl<'ast> ast::visit::Visit<'ast> for ReferencedVariables {
86308676
fn visit_typed_expr_var(
86318677
&mut self,
8632-
_location: &'ast SrcSpan,
8678+
location: &'ast SrcSpan,
86338679
constructor: &'ast ValueConstructor,
86348680
name: &'ast EcoString,
86358681
) {
86368682
match &constructor.variant {
8637-
type_::ValueConstructorVariant::LocalVariable { location, .. } => {
8638-
self.register(name, &constructor.type_, *location);
8683+
type_::ValueConstructorVariant::LocalVariable {
8684+
location: definition_location,
8685+
..
8686+
} => {
8687+
self.register_referenced_variable(
8688+
name,
8689+
&constructor.type_,
8690+
*location,
8691+
*definition_location,
8692+
);
86398693
}
86408694
type_::ValueConstructorVariant::ModuleConstant { .. }
86418695
| type_::ValueConstructorVariant::LocalConstant { .. }

compiler-core/src/language_server/tests/action.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10299,3 +10299,33 @@ pub fn do_things(a, b) {
1029910299
find_position_of("let").select_until(find_position_of("+ a * b\n").under_char('\n'))
1030010300
);
1030110301
}
10302+
10303+
#[test]
10304+
fn extract_function_which_uses_multiple_extracted_variables() {
10305+
assert_code_action!(
10306+
EXTRACT_FUNCTION,
10307+
"
10308+
pub fn do_things(a, b) {
10309+
let wibble = a + b
10310+
let wobble = a * b
10311+
wobble / wibble
10312+
}
10313+
",
10314+
find_position_of("let").select_until(find_position_of("* b\n").under_char('\n'))
10315+
);
10316+
}
10317+
10318+
#[test]
10319+
fn extract_function_which_uses_no_extracted_variables() {
10320+
assert_code_action!(
10321+
EXTRACT_FUNCTION,
10322+
"
10323+
pub fn do_things(a, b) {
10324+
let x = a + b
10325+
echo x
10326+
a
10327+
}
10328+
",
10329+
find_position_of("let").select_until(find_position_of("echo x\n").under_char('\n'))
10330+
);
10331+
}

compiler-core/src/language_server/tests/snapshots/gleam_core__language_server__tests__action__extract_function_from_statements.snap

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ pub fn do_things(a, b) {
1818
----- AFTER ACTION
1919

2020
pub fn do_things(a, b) {
21-
function(a, b)
21+
let result = function(a, b)
2222
result + 3
2323
}
2424

2525
fn function(a: Int, b: Int) -> Int {
2626
let a = 10 + a
2727
let b = 10 + b
2828
let result = a * b
29+
result
2930
}

0 commit comments

Comments
 (0)