Skip to content

Commit b6398fd

Browse files
GearsDatapackslpil
authored andcommitted
Implement extract function code action for expressions
1 parent 5e119f8 commit b6398fd

File tree

5 files changed

+250
-2
lines changed

5 files changed

+250
-2
lines changed

compiler-core/src/ast.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,6 +2276,10 @@ impl SrcSpan {
22762276
byte_index >= self.start && byte_index <= self.end
22772277
}
22782278

2279+
pub fn contains_span(&self, span: SrcSpan) -> bool {
2280+
self.contains(span.start) && self.contains(span.end)
2281+
}
2282+
22792283
/// Merges two spans into a new one that starts at the start of the smaller
22802284
/// one and ends at the end of the bigger one. For example:
22812285
///

compiler-core/src/language_server/code_action.rs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8411,3 +8411,190 @@ impl<'ast> ast::visit::Visit<'ast> for AddOmittedLabels<'ast> {
84118411
self.arguments_and_omitted_labels = Some(omitted_labels);
84128412
}
84138413
}
8414+
8415+
/// Code action to extract selected code into a separate function.
8416+
pub struct ExtractFunction<'a> {
8417+
module: &'a Module,
8418+
params: &'a CodeActionParams,
8419+
edits: TextEdits<'a>,
8420+
extract: Extract<'a>,
8421+
function_end_position: Option<u32>,
8422+
}
8423+
8424+
#[derive(Debug)]
8425+
enum Extract<'a> {
8426+
None,
8427+
Expression(&'a TypedExpr),
8428+
Statements(Vec<&'a TypedStatement>),
8429+
}
8430+
8431+
impl<'a> ExtractFunction<'a> {
8432+
pub fn new(
8433+
module: &'a Module,
8434+
line_numbers: &'a LineNumbers,
8435+
params: &'a CodeActionParams,
8436+
) -> Self {
8437+
Self {
8438+
module,
8439+
params,
8440+
edits: TextEdits::new(line_numbers),
8441+
extract: Extract::None,
8442+
function_end_position: None,
8443+
}
8444+
}
8445+
8446+
pub fn code_actions(mut self) -> Vec<CodeAction> {
8447+
if self.params.range.start == self.params.range.end {
8448+
return Vec::new();
8449+
}
8450+
8451+
self.visit_typed_module(&self.module.ast);
8452+
8453+
let Some(end) = self.function_end_position else {
8454+
return Vec::new();
8455+
};
8456+
8457+
match std::mem::replace(&mut self.extract, Extract::None) {
8458+
Extract::None => return Vec::new(),
8459+
Extract::Expression(expression) => self.extract_expression(expression, end),
8460+
Extract::Statements(statements) => self.extract_statements(statements, end),
8461+
}
8462+
8463+
let mut action = Vec::with_capacity(1);
8464+
CodeActionBuilder::new("Extract function")
8465+
.kind(CodeActionKind::REFACTOR_EXTRACT)
8466+
.changes(self.params.text_document.uri.clone(), self.edits.edits)
8467+
.preferred(false)
8468+
.push_to(&mut action);
8469+
action
8470+
}
8471+
8472+
fn extract_expression(&mut self, expression: &TypedExpr, function_end: u32) {
8473+
let referenced_variables = referenced_variables(expression);
8474+
8475+
let expression_code = code_at(self.module, expression.location());
8476+
8477+
let arguments = referenced_variables.iter().map(|(name, _)| name).join(", ");
8478+
let call = format!("function({arguments})");
8479+
self.edits.replace(expression.location(), call);
8480+
8481+
let mut printer = Printer::new(&self.module.ast.names);
8482+
8483+
let parameters = referenced_variables
8484+
.iter()
8485+
.map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_)))
8486+
.join(", ");
8487+
let return_type = printer.print_type(&expression.type_());
8488+
8489+
let function = format!(
8490+
"\n\nfn function({parameters}) -> {return_type} {{
8491+
{expression_code}
8492+
}}"
8493+
);
8494+
8495+
self.edits.insert(function_end, function);
8496+
}
8497+
8498+
fn extract_statements(&mut self, statements: Vec<&TypedStatement>, function_end: u32) {
8499+
todo!("Implement for statements")
8500+
}
8501+
}
8502+
8503+
impl<'ast> ast::visit::Visit<'ast> for ExtractFunction<'ast> {
8504+
fn visit_typed_function(&mut self, function: &'ast ast::TypedFunction) {
8505+
let range = self.edits.src_span_to_lsp_range(function.full_location());
8506+
8507+
if within(self.params.range, range) {
8508+
self.function_end_position = Some(function.end_position);
8509+
8510+
ast::visit::visit_typed_function(self, function);
8511+
}
8512+
}
8513+
8514+
fn visit_typed_expr(&mut self, expression: &'ast TypedExpr) {
8515+
match &self.extract {
8516+
Extract::None => {
8517+
let range = self.edits.src_span_to_lsp_range(expression.location());
8518+
8519+
if within(range, self.params.range) {
8520+
self.extract = Extract::Expression(expression);
8521+
return;
8522+
}
8523+
}
8524+
Extract::Expression(_) | Extract::Statements(_) => {}
8525+
}
8526+
ast::visit::visit_typed_expr(self, expression);
8527+
}
8528+
8529+
fn visit_typed_statement(&mut self, statement: &'ast TypedStatement) {
8530+
let range = self.edits.src_span_to_lsp_range(statement.location());
8531+
if within(range, self.params.range) {
8532+
match &mut self.extract {
8533+
Extract::None => {
8534+
self.extract = Extract::Statements(vec![statement]);
8535+
}
8536+
Extract::Expression(expression) => {
8537+
if expression.location().contains_span(statement.location()) {
8538+
return;
8539+
}
8540+
8541+
self.extract = Extract::Statements(vec![statement]);
8542+
}
8543+
Extract::Statements(statements) => {
8544+
statements.push(statement);
8545+
}
8546+
}
8547+
} else {
8548+
ast::visit::visit_typed_statement(self, statement);
8549+
}
8550+
}
8551+
}
8552+
8553+
fn referenced_variables(expression: &TypedExpr) -> Vec<(EcoString, Arc<Type>)> {
8554+
let mut references = ReferencedVariables {
8555+
variables: Vec::new(),
8556+
defined_variables: HashSet::new(),
8557+
};
8558+
references.visit_typed_expr(expression);
8559+
references.variables
8560+
}
8561+
8562+
struct ReferencedVariables {
8563+
variables: Vec<(EcoString, Arc<Type>)>,
8564+
defined_variables: HashSet<EcoString>,
8565+
}
8566+
8567+
impl ReferencedVariables {
8568+
fn register(&mut self, name: &EcoString, type_: &Arc<Type>) {
8569+
if self.defined_variables.contains(name) {
8570+
return;
8571+
}
8572+
8573+
if !self
8574+
.variables
8575+
.iter()
8576+
.any(|(variable_name, _)| variable_name == name)
8577+
{
8578+
self.variables.push((name.clone(), type_.clone()))
8579+
}
8580+
}
8581+
}
8582+
8583+
impl<'ast> ast::visit::Visit<'ast> for ReferencedVariables {
8584+
fn visit_typed_expr_var(
8585+
&mut self,
8586+
_location: &'ast SrcSpan,
8587+
constructor: &'ast ValueConstructor,
8588+
name: &'ast EcoString,
8589+
) {
8590+
match &constructor.variant {
8591+
type_::ValueConstructorVariant::LocalVariable { .. } => {
8592+
self.register(name, &constructor.type_);
8593+
}
8594+
type_::ValueConstructorVariant::ModuleConstant { .. }
8595+
| type_::ValueConstructorVariant::LocalConstant { .. }
8596+
| type_::ValueConstructorVariant::ModuleFn { .. }
8597+
| type_::ValueConstructorVariant::Record { .. } => {}
8598+
}
8599+
}
8600+
}

compiler-core/src/language_server/engine.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use crate::{
1313
io::{BeamCompiler, CommandExecutor, FileSystemReader, FileSystemWriter},
1414
language_server::{
1515
code_action::{
16-
AddOmittedLabels, CollapseNestedCase, RemoveBlock, RemovePrivateOpaque,
17-
RemoveUnreachableBranches,
16+
AddOmittedLabels, CollapseNestedCase, ExtractFunction, RemoveBlock,
17+
RemovePrivateOpaque, RemoveUnreachableBranches,
1818
},
1919
compiler::LspProjectCompiler,
2020
files::FileSystemProxy,
@@ -450,6 +450,7 @@ where
450450
actions.extend(WrapInBlock::new(module, &lines, &params).code_actions());
451451
actions.extend(RemoveBlock::new(module, &lines, &params).code_actions());
452452
actions.extend(RemovePrivateOpaque::new(module, &lines, &params).code_actions());
453+
actions.extend(ExtractFunction::new(module, &lines, &params).code_actions());
453454
GenerateDynamicDecoder::new(module, &lines, &params, &mut actions).code_actions();
454455
GenerateJsonEncoder::new(
455456
module,

compiler-core/src/language_server/tests/action.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ const REMOVE_OPAQUE_FROM_PRIVATE_TYPE: &str = "Remove opaque from private type";
135135
const COLLAPSE_NESTED_CASE: &str = "Collapse nested case";
136136
const REMOVE_UNREACHABLE_BRANCHES: &str = "Remove unreachable branches";
137137
const ADD_OMITTED_LABELS: &str = "Add omitted labels";
138+
const EXTRACT_FUNCTION: &str = "Extract function";
138139

139140
macro_rules! assert_code_action {
140141
($title:expr, $code:literal, $range:expr $(,)?) => {
@@ -10227,3 +10228,23 @@ pub fn labelled(a, b) { todo }
1022710228
find_position_of("labelled").to_selection(),
1022810229
);
1022910230
}
10231+
10232+
#[test]
10233+
fn extract_function() {
10234+
assert_code_action!(
10235+
EXTRACT_FUNCTION,
10236+
"
10237+
pub fn do_things(a, b) {
10238+
let result = {
10239+
let a = 10 + a
10240+
let b = 10 + b
10241+
a * b
10242+
}
10243+
result + 3
10244+
}
10245+
",
10246+
find_position_of("{")
10247+
.nth_occurrence(2)
10248+
.select_until(find_position_of("}\n").under_char('\n'))
10249+
);
10250+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
---
2+
source: compiler-core/src/language_server/tests/action.rs
3+
expression: "\npub fn do_things(a, b) {\n let result = {\n let a = 10 + a\n let b = 10 + b\n a * b\n }\n result + 3\n}\n"
4+
---
5+
----- BEFORE ACTION
6+
7+
pub fn do_things(a, b) {
8+
let result = {
9+
10+
let a = 10 + a
11+
▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔
12+
let b = 10 + b
13+
▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔
14+
a * b
15+
▔▔▔▔▔▔▔▔▔
16+
}
17+
▔▔▔
18+
result + 3
19+
}
20+
21+
22+
----- AFTER ACTION
23+
24+
pub fn do_things(a, b) {
25+
let result = function(a, b)
26+
result + 3
27+
}
28+
29+
fn function(a: Int, b: Int) -> Int {
30+
{
31+
let a = 10 + a
32+
let b = 10 + b
33+
a * b
34+
}
35+
}

0 commit comments

Comments
 (0)