Skip to content

Commit f102734

Browse files
committed
feat: add 'Annotate all top level type definitions' code action
1 parent be19e79 commit f102734

File tree

2 files changed

+179
-3
lines changed

2 files changed

+179
-3
lines changed

compiler-core/src/language_server/code_action.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,179 @@ impl<'a> AddAnnotations<'a> {
13601360
}
13611361
}
13621362

1363+
/// Code action to add type annotations to all top level definitions
1364+
///
1365+
pub struct AnnotateTopLevelTypeDefinitions<'a> {
1366+
module: &'a Module,
1367+
params: &'a CodeActionParams,
1368+
edits: TextEdits<'a>,
1369+
printer: Printer<'a>,
1370+
is_hovering_definition: bool,
1371+
}
1372+
1373+
impl<'a> AnnotateTopLevelTypeDefinitions<'a> {
1374+
pub fn new(
1375+
module: &'a Module,
1376+
line_numbers: &'a LineNumbers,
1377+
params: &'a CodeActionParams,
1378+
) -> Self {
1379+
Self {
1380+
module,
1381+
params,
1382+
edits: TextEdits::new(line_numbers),
1383+
// We need to use the same printer for all the edits because otherwise
1384+
// we could get duplicate type variable names.
1385+
printer: Printer::new_without_type_variables(&module.ast.names),
1386+
is_hovering_definition: false,
1387+
}
1388+
}
1389+
1390+
pub fn code_actions(mut self) -> Vec<CodeAction> {
1391+
self.visit_typed_module(&self.module.ast);
1392+
1393+
// We only want to trigger the action if we're over one of the definition in
1394+
// the module
1395+
if !self.is_hovering_definition {
1396+
return vec![];
1397+
};
1398+
1399+
let mut action = Vec::with_capacity(1);
1400+
CodeActionBuilder::new("Annotate all top level type definitions")
1401+
.kind(CodeActionKind::REFACTOR_REWRITE)
1402+
.changes(self.params.text_document.uri.clone(), self.edits.edits)
1403+
.preferred(false)
1404+
.push_to(&mut action);
1405+
action
1406+
}
1407+
}
1408+
1409+
impl<'ast> ast::visit::Visit<'ast> for AnnotateTopLevelTypeDefinitions<'_> {
1410+
fn visit_typed_module_constant(&mut self, constant: &'ast TypedModuleConstant) {
1411+
// Since type variable names are local to definitions, any type variables
1412+
// in other parts of the module shouldn't affect what we print for the
1413+
// annotations of this constant.
1414+
self.printer.clear_type_variables();
1415+
1416+
let code_action_range = self.edits.src_span_to_lsp_range(constant.location);
1417+
1418+
if overlaps(code_action_range, self.params.range) {
1419+
self.is_hovering_definition = true;
1420+
}
1421+
1422+
// We don't need to add an annotation if there already is one
1423+
if constant.annotation.is_some() {
1424+
return;
1425+
}
1426+
1427+
self.edits.insert(
1428+
constant.name_location.end,
1429+
format!(": {}", self.printer.print_type(&constant.type_)),
1430+
);
1431+
}
1432+
1433+
fn visit_typed_function(&mut self, fun: &'ast ast::TypedFunction) {
1434+
// Since type variable names are local to definitions, any type variables
1435+
// in other parts of the module shouldn't affect what we print for the
1436+
// annotations of this functions. The only variables which cannot clash
1437+
// are ones defined in the signature of this function, which we register
1438+
// when we visit the parameters of this function inside `collect_type_variables`.
1439+
self.printer.clear_type_variables();
1440+
collect_type_variables(&mut self.printer, fun);
1441+
1442+
ast::visit::visit_typed_function(self, fun);
1443+
1444+
let code_action_range = self.edits.src_span_to_lsp_range(
1445+
fun.body_start
1446+
.map(|body_start| SrcSpan {
1447+
start: fun.location.start,
1448+
end: body_start,
1449+
})
1450+
.unwrap_or(fun.location),
1451+
);
1452+
1453+
if overlaps(code_action_range, self.params.range) {
1454+
self.is_hovering_definition = true;
1455+
}
1456+
1457+
// Annotate each argument separately
1458+
for argument in fun.arguments.iter() {
1459+
// Don't annotate the argument if it's already annotated
1460+
if argument.annotation.is_some() {
1461+
continue;
1462+
}
1463+
1464+
self.edits.insert(
1465+
argument.location.end,
1466+
format!(": {}", self.printer.print_type(&argument.type_)),
1467+
);
1468+
}
1469+
1470+
// Annotate the return type if it isn't already annotated
1471+
if fun.return_annotation.is_none() {
1472+
self.edits.insert(
1473+
fun.location.end,
1474+
format!(" -> {}", self.printer.print_type(&fun.return_type)),
1475+
);
1476+
}
1477+
}
1478+
1479+
fn visit_typed_expr_fn(
1480+
&mut self,
1481+
location: &'ast SrcSpan,
1482+
type_: &'ast Arc<Type>,
1483+
kind: &'ast FunctionLiteralKind,
1484+
arguments: &'ast [TypedArg],
1485+
body: &'ast Vec1<TypedStatement>,
1486+
return_annotation: &'ast Option<ast::TypeAst>,
1487+
) {
1488+
ast::visit::visit_typed_expr_fn(
1489+
self,
1490+
location,
1491+
type_,
1492+
kind,
1493+
arguments,
1494+
body,
1495+
return_annotation,
1496+
);
1497+
1498+
// If the function doesn't have a head, we can't annotate it
1499+
let location = match kind {
1500+
// Function captures don't need any type annotations
1501+
FunctionLiteralKind::Capture { .. } => return,
1502+
FunctionLiteralKind::Anonymous { head } => head,
1503+
FunctionLiteralKind::Use { location } => location,
1504+
};
1505+
1506+
let code_action_range = self.edits.src_span_to_lsp_range(*location);
1507+
1508+
if overlaps(code_action_range, self.params.range) {
1509+
self.is_hovering_definition = true;
1510+
}
1511+
1512+
// Annotate each argument separately
1513+
for argument in arguments.iter() {
1514+
// Don't annotate the argument if it's already annotated
1515+
if argument.annotation.is_some() {
1516+
continue;
1517+
}
1518+
1519+
self.edits.insert(
1520+
argument.location.end,
1521+
format!(": {}", self.printer.print_type(&argument.type_)),
1522+
);
1523+
}
1524+
1525+
// Annotate the return type if it isn't already annotated, and this is
1526+
// an anonymous function.
1527+
if return_annotation.is_none() && matches!(kind, FunctionLiteralKind::Anonymous { .. }) {
1528+
let return_type = &type_.return_type().expect("Type must be a function");
1529+
let pretty_type = self.printer.print_type(return_type);
1530+
self.edits
1531+
.insert(location.end, format!(" -> {pretty_type}"));
1532+
}
1533+
}
1534+
}
1535+
13631536
struct TypeVariableCollector<'a, 'b> {
13641537
printer: &'a mut Printer<'b>,
13651538
}

compiler-core/src/language_server/engine.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ use std::{collections::HashSet, sync::Arc};
4444
use super::{
4545
DownloadDependencies, MakeLocker,
4646
code_action::{
47-
AddAnnotations, CodeActionBuilder, ConvertFromUse, ConvertToFunctionCall, ConvertToPipe,
48-
ConvertToUse, ExpandFunctionCapture, ExtractConstant, ExtractVariable,
49-
FillInMissingLabelledArgs, FillUnusedFields, FixBinaryOperation,
47+
AddAnnotations, AnnotateTopLevelTypeDefinitions, CodeActionBuilder, ConvertFromUse,
48+
ConvertToFunctionCall, ConvertToPipe, ConvertToUse, ExpandFunctionCapture, ExtractConstant,
49+
ExtractVariable, FillInMissingLabelledArgs, FillUnusedFields, FixBinaryOperation,
5050
FixTruncatedBitArraySegment, GenerateDynamicDecoder, GenerateFunction, GenerateJsonEncoder,
5151
GenerateVariant, InlineVariable, InterpolateString, LetAssertToCase, PatternMatchOnValue,
5252
RedundantTupleInCaseSubject, RemoveEchos, RemoveUnusedImports, UseLabelShorthandSyntax,
@@ -461,6 +461,9 @@ where
461461
)
462462
.code_actions();
463463
AddAnnotations::new(module, &lines, &params).code_action(&mut actions);
464+
actions.extend(
465+
AnnotateTopLevelTypeDefinitions::new(module, &lines, &params).code_actions(),
466+
);
464467
Ok(if actions.is_empty() {
465468
None
466469
} else {

0 commit comments

Comments
 (0)