@@ -1360,6 +1360,179 @@ impl<'a> AddAnnotations<'a> {
1360
1360
}
1361
1361
}
1362
1362
1363
+ /// Code action to add type annotations to all top level definitions
1364
+ ///
1365
+ pub struct AnnotateTopLevelTypeDefinitions < ' a > {
1366
+ module : & ' a Module ,
1367
+ params : & ' a CodeActionParams ,
1368
+ edits : TextEdits < ' a > ,
1369
+ printer : Printer < ' a > ,
1370
+ is_hovering_definition : bool ,
1371
+ }
1372
+
1373
+ impl < ' a > AnnotateTopLevelTypeDefinitions < ' a > {
1374
+ pub fn new (
1375
+ module : & ' a Module ,
1376
+ line_numbers : & ' a LineNumbers ,
1377
+ params : & ' a CodeActionParams ,
1378
+ ) -> Self {
1379
+ Self {
1380
+ module,
1381
+ params,
1382
+ edits : TextEdits :: new ( line_numbers) ,
1383
+ // We need to use the same printer for all the edits because otherwise
1384
+ // we could get duplicate type variable names.
1385
+ printer : Printer :: new_without_type_variables ( & module. ast . names ) ,
1386
+ is_hovering_definition : false ,
1387
+ }
1388
+ }
1389
+
1390
+ pub fn code_actions ( mut self ) -> Vec < CodeAction > {
1391
+ self . visit_typed_module ( & self . module . ast ) ;
1392
+
1393
+ // We only want to trigger the action if we're over one of the definition in
1394
+ // the module
1395
+ if !self . is_hovering_definition {
1396
+ return vec ! [ ] ;
1397
+ } ;
1398
+
1399
+ let mut action = Vec :: with_capacity ( 1 ) ;
1400
+ CodeActionBuilder :: new ( "Annotate all top level type definitions" )
1401
+ . kind ( CodeActionKind :: REFACTOR_REWRITE )
1402
+ . changes ( self . params . text_document . uri . clone ( ) , self . edits . edits )
1403
+ . preferred ( false )
1404
+ . push_to ( & mut action) ;
1405
+ action
1406
+ }
1407
+ }
1408
+
1409
+ impl < ' ast > ast:: visit:: Visit < ' ast > for AnnotateTopLevelTypeDefinitions < ' _ > {
1410
+ fn visit_typed_module_constant ( & mut self , constant : & ' ast TypedModuleConstant ) {
1411
+ // Since type variable names are local to definitions, any type variables
1412
+ // in other parts of the module shouldn't affect what we print for the
1413
+ // annotations of this constant.
1414
+ self . printer . clear_type_variables ( ) ;
1415
+
1416
+ let code_action_range = self . edits . src_span_to_lsp_range ( constant. location ) ;
1417
+
1418
+ if overlaps ( code_action_range, self . params . range ) {
1419
+ self . is_hovering_definition = true ;
1420
+ }
1421
+
1422
+ // We don't need to add an annotation if there already is one
1423
+ if constant. annotation . is_some ( ) {
1424
+ return ;
1425
+ }
1426
+
1427
+ self . edits . insert (
1428
+ constant. name_location . end ,
1429
+ format ! ( ": {}" , self . printer. print_type( & constant. type_) ) ,
1430
+ ) ;
1431
+ }
1432
+
1433
+ fn visit_typed_function ( & mut self , fun : & ' ast ast:: TypedFunction ) {
1434
+ // Since type variable names are local to definitions, any type variables
1435
+ // in other parts of the module shouldn't affect what we print for the
1436
+ // annotations of this functions. The only variables which cannot clash
1437
+ // are ones defined in the signature of this function, which we register
1438
+ // when we visit the parameters of this function inside `collect_type_variables`.
1439
+ self . printer . clear_type_variables ( ) ;
1440
+ collect_type_variables ( & mut self . printer , fun) ;
1441
+
1442
+ ast:: visit:: visit_typed_function ( self , fun) ;
1443
+
1444
+ let code_action_range = self . edits . src_span_to_lsp_range (
1445
+ fun. body_start
1446
+ . map ( |body_start| SrcSpan {
1447
+ start : fun. location . start ,
1448
+ end : body_start,
1449
+ } )
1450
+ . unwrap_or ( fun. location ) ,
1451
+ ) ;
1452
+
1453
+ if overlaps ( code_action_range, self . params . range ) {
1454
+ self . is_hovering_definition = true ;
1455
+ }
1456
+
1457
+ // Annotate each argument separately
1458
+ for argument in fun. arguments . iter ( ) {
1459
+ // Don't annotate the argument if it's already annotated
1460
+ if argument. annotation . is_some ( ) {
1461
+ continue ;
1462
+ }
1463
+
1464
+ self . edits . insert (
1465
+ argument. location . end ,
1466
+ format ! ( ": {}" , self . printer. print_type( & argument. type_) ) ,
1467
+ ) ;
1468
+ }
1469
+
1470
+ // Annotate the return type if it isn't already annotated
1471
+ if fun. return_annotation . is_none ( ) {
1472
+ self . edits . insert (
1473
+ fun. location . end ,
1474
+ format ! ( " -> {}" , self . printer. print_type( & fun. return_type) ) ,
1475
+ ) ;
1476
+ }
1477
+ }
1478
+
1479
+ fn visit_typed_expr_fn (
1480
+ & mut self ,
1481
+ location : & ' ast SrcSpan ,
1482
+ type_ : & ' ast Arc < Type > ,
1483
+ kind : & ' ast FunctionLiteralKind ,
1484
+ arguments : & ' ast [ TypedArg ] ,
1485
+ body : & ' ast Vec1 < TypedStatement > ,
1486
+ return_annotation : & ' ast Option < ast:: TypeAst > ,
1487
+ ) {
1488
+ ast:: visit:: visit_typed_expr_fn (
1489
+ self ,
1490
+ location,
1491
+ type_,
1492
+ kind,
1493
+ arguments,
1494
+ body,
1495
+ return_annotation,
1496
+ ) ;
1497
+
1498
+ // If the function doesn't have a head, we can't annotate it
1499
+ let location = match kind {
1500
+ // Function captures don't need any type annotations
1501
+ FunctionLiteralKind :: Capture { .. } => return ,
1502
+ FunctionLiteralKind :: Anonymous { head } => head,
1503
+ FunctionLiteralKind :: Use { location } => location,
1504
+ } ;
1505
+
1506
+ let code_action_range = self . edits . src_span_to_lsp_range ( * location) ;
1507
+
1508
+ if overlaps ( code_action_range, self . params . range ) {
1509
+ self . is_hovering_definition = true ;
1510
+ }
1511
+
1512
+ // Annotate each argument separately
1513
+ for argument in arguments. iter ( ) {
1514
+ // Don't annotate the argument if it's already annotated
1515
+ if argument. annotation . is_some ( ) {
1516
+ continue ;
1517
+ }
1518
+
1519
+ self . edits . insert (
1520
+ argument. location . end ,
1521
+ format ! ( ": {}" , self . printer. print_type( & argument. type_) ) ,
1522
+ ) ;
1523
+ }
1524
+
1525
+ // Annotate the return type if it isn't already annotated, and this is
1526
+ // an anonymous function.
1527
+ if return_annotation. is_none ( ) && matches ! ( kind, FunctionLiteralKind :: Anonymous { .. } ) {
1528
+ let return_type = & type_. return_type ( ) . expect ( "Type must be a function" ) ;
1529
+ let pretty_type = self . printer . print_type ( return_type) ;
1530
+ self . edits
1531
+ . insert ( location. end , format ! ( " -> {pretty_type}" ) ) ;
1532
+ }
1533
+ }
1534
+ }
1535
+
1363
1536
struct TypeVariableCollector < ' a , ' b > {
1364
1537
printer : & ' a mut Printer < ' b > ,
1365
1538
}
0 commit comments