Skip to content

Commit bcb8510

Browse files
Implement rename and improve find-all-references in LSP (#1635)
Co-authored-by: Daniel Rosenwasser <[email protected]>
1 parent d85436e commit bcb8510

File tree

49 files changed

+1997
-310
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1997
-310
lines changed

internal/ast/ast.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,16 @@ func (n *Node) ModuleSpecifier() *Expression {
998998
panic("Unhandled case in Node.ModuleSpecifier: " + n.Kind.String())
999999
}
10001000

1001+
func (n *Node) ImportClause() *Node {
1002+
switch n.Kind {
1003+
case KindImportDeclaration, KindJSImportDeclaration:
1004+
return n.AsImportDeclaration().ImportClause
1005+
case KindJSDocImportTag:
1006+
return n.AsJSDocImportTag().ImportClause
1007+
}
1008+
panic("Unhandled case in Node.ImportClause: " + n.Kind.String())
1009+
}
1010+
10011011
func (n *Node) Statement() *Statement {
10021012
switch n.Kind {
10031013
case KindDoStatement:
@@ -10117,6 +10127,10 @@ func (node *JSDocCallbackTag) Clone(f NodeFactoryCoercible) *Node {
1011710127
return cloneNode(f.AsNodeFactory().NewJSDocCallbackTag(node.TagName, node.TypeExpression, node.FullName, node.Comment), node.AsNode(), f.AsNodeFactory().hooks)
1011810128
}
1011910129

10130+
func IsJSDocCallbackTag(node *Node) bool {
10131+
return node.Kind == KindJSDocCallbackTag
10132+
}
10133+
1012010134
// JSDocOverloadTag
1012110135
type JSDocOverloadTag struct {
1012210136
JSDocTagBase
@@ -10190,6 +10204,10 @@ func (node *JSDocTypedefTag) Clone(f NodeFactoryCoercible) *Node {
1019010204

1019110205
func (node *JSDocTypedefTag) Name() *DeclarationName { return node.name }
1019210206

10207+
func IsJSDocTypedefTag(node *Node) bool {
10208+
return node.Kind == KindJSDocTypedefTag
10209+
}
10210+
1019310211
// JSDocTypeLiteral
1019410212
type JSDocTypeLiteral struct {
1019510213
TypeNodeBase

internal/checker/exports.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,7 @@ func (c *Checker) GetIndexSignaturesAtLocation(node *ast.Node) []*ast.Node {
150150
func (c *Checker) GetResolvedSymbol(node *ast.Node) *ast.Symbol {
151151
return c.getResolvedSymbol(node)
152152
}
153+
154+
func (c *Checker) GetImmediateAliasedSymbol(symbol *ast.Symbol) *ast.Symbol {
155+
return c.getImmediateAliasedSymbol(symbol)
156+
}

internal/checker/services.go

Lines changed: 132 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,14 @@ func runWithoutResolvedSignatureCaching[T any](c *Checker, node *ast.Node, fn fu
347347

348348
func (c *Checker) GetRootSymbols(symbol *ast.Symbol) []*ast.Symbol {
349349
roots := c.getImmediateRootSymbols(symbol)
350-
if roots != nil {
351-
var result []*ast.Symbol
352-
for _, root := range roots {
353-
result = append(result, c.GetRootSymbols(root)...)
354-
}
355-
return result
350+
if len(roots) == 0 {
351+
return []*ast.Symbol{symbol}
356352
}
357-
return []*ast.Symbol{symbol}
353+
var result []*ast.Symbol
354+
for _, root := range roots {
355+
result = append(result, c.GetRootSymbols(root)...)
356+
}
357+
return result
358358
}
359359

360360
func (c *Checker) getImmediateRootSymbols(symbol *ast.Symbol) []*ast.Symbol {
@@ -364,7 +364,8 @@ func (c *Checker) getImmediateRootSymbols(symbol *ast.Symbol) []*ast.Symbol {
364364
func(t *Type) *ast.Symbol {
365365
return c.getPropertyOfType(t, symbol.Name)
366366
})
367-
} else if symbol.Flags&ast.SymbolFlagsTransient != 0 {
367+
}
368+
if symbol.Flags&ast.SymbolFlagsTransient != 0 {
368369
if c.spreadLinks.Has(symbol) {
369370
leftSpread := c.spreadLinks.Get(symbol).leftSpread
370371
rightSpread := c.spreadLinks.Get(symbol).rightSpread
@@ -382,9 +383,7 @@ func (c *Checker) getImmediateRootSymbols(symbol *ast.Symbol) []*ast.Symbol {
382383
if target != nil {
383384
return []*ast.Symbol{target}
384385
}
385-
return nil
386386
}
387-
388387
return nil
389388
}
390389

@@ -418,20 +417,16 @@ func (c *Checker) GetExportSpecifierLocalTargetSymbol(node *ast.Node) *ast.Symbo
418417
if node.Parent.Parent.AsExportDeclaration().ModuleSpecifier != nil {
419418
return c.getExternalModuleMember(node.Parent.Parent, node, false /*dontResolveAlias*/)
420419
}
421-
name := node.PropertyName()
422-
if name == nil {
423-
name = node.Name()
424-
}
420+
name := node.PropertyNameOrName()
425421
if name.Kind == ast.KindStringLiteral {
426422
// Skip for invalid syntax like this: export { "x" }
427423
return nil
428424
}
425+
return c.resolveEntityName(name, ast.SymbolFlagsValue|ast.SymbolFlagsType|ast.SymbolFlagsNamespace|ast.SymbolFlagsAlias, true /*ignoreErrors*/, false, nil)
429426
case ast.KindIdentifier:
430-
// do nothing (don't panic)
431-
default:
432-
panic("Unhandled case in getExportSpecifierLocalTargetSymbol, node should be ExportSpecifier | Identifier")
427+
return c.resolveEntityName(node, ast.SymbolFlagsValue|ast.SymbolFlagsType|ast.SymbolFlagsNamespace|ast.SymbolFlagsAlias, true /*ignoreErrors*/, false, nil)
433428
}
434-
return c.resolveEntityName(node, ast.SymbolFlagsValue|ast.SymbolFlagsType|ast.SymbolFlagsNamespace|ast.SymbolFlagsAlias, true /*ignoreErrors*/, false, nil)
429+
panic("Unhandled case in getExportSpecifierLocalTargetSymbol, node should be ExportSpecifier | Identifier")
435430
}
436431

437432
func (c *Checker) GetShorthandAssignmentValueSymbol(location *ast.Node) *ast.Symbol {
@@ -683,3 +678,122 @@ func (c *Checker) GetFirstTypeArgumentFromKnownType(t *Type) *Type {
683678
}
684679
return nil
685680
}
681+
682+
// Gets all symbols for one property. Does not get symbols for every property.
683+
func (c *Checker) GetPropertySymbolsFromContextualType(node *ast.Node, contextualType *Type, unionSymbolOk bool) []*ast.Symbol {
684+
name := ast.GetTextOfPropertyName(node.Name())
685+
if name == "" {
686+
return nil
687+
}
688+
if contextualType.flags&TypeFlagsUnion == 0 {
689+
if symbol := c.getPropertyOfType(contextualType, name); symbol != nil {
690+
return []*ast.Symbol{symbol}
691+
}
692+
return nil
693+
}
694+
filteredTypes := contextualType.Types()
695+
if ast.IsObjectLiteralExpression(node.Parent) || ast.IsJsxAttributes(node.Parent) {
696+
filteredTypes = core.Filter(filteredTypes, func(t *Type) bool {
697+
return !c.IsTypeInvalidDueToUnionDiscriminant(t, node.Parent)
698+
})
699+
}
700+
discriminatedPropertySymbols := core.MapNonNil(filteredTypes, func(t *Type) *ast.Symbol {
701+
return c.getPropertyOfType(t, name)
702+
})
703+
if unionSymbolOk && (len(discriminatedPropertySymbols) == 0 || len(discriminatedPropertySymbols) == len(contextualType.Types())) {
704+
if symbol := c.getPropertyOfType(contextualType, name); symbol != nil {
705+
return []*ast.Symbol{symbol}
706+
}
707+
}
708+
if len(filteredTypes) == 0 && len(discriminatedPropertySymbols) == 0 {
709+
// Bad discriminant -- do again without discriminating
710+
return core.MapNonNil(contextualType.Types(), func(t *Type) *ast.Symbol {
711+
return c.getPropertyOfType(t, name)
712+
})
713+
}
714+
// by eliminating duplicates we might even end up with a single symbol
715+
// that helps with displaying better quick infos on properties of union types
716+
return core.Deduplicate(discriminatedPropertySymbols)
717+
}
718+
719+
// Gets the property symbol corresponding to the property in destructuring assignment
720+
// 'property1' from
721+
//
722+
// for ( { property1: a } of elems) {
723+
// }
724+
//
725+
// 'property1' at location 'a' from:
726+
//
727+
// [a] = [ property1, property2 ]
728+
func (c *Checker) GetPropertySymbolOfDestructuringAssignment(location *ast.Node) *ast.Symbol {
729+
if isArrayLiteralOrObjectLiteralDestructuringPattern(location.Parent.Parent) {
730+
// Get the type of the object or array literal and then look for property of given name in the type
731+
if typeOfObjectLiteral := c.getTypeOfAssignmentPattern(location.Parent.Parent); typeOfObjectLiteral != nil {
732+
return c.getPropertyOfType(typeOfObjectLiteral, location.Text())
733+
}
734+
}
735+
return nil
736+
}
737+
738+
// Gets the type of object literal or array literal of destructuring assignment.
739+
// { a } from
740+
//
741+
// for ( { a } of elems) {
742+
// }
743+
//
744+
// [ a ] from
745+
//
746+
// [a] = [ some array ...]
747+
func (c *Checker) getTypeOfAssignmentPattern(expr *ast.Node) *Type {
748+
// If this is from "for of"
749+
// for ( { a } of elems) {
750+
// }
751+
if ast.IsForOfStatement(expr.Parent) {
752+
iteratedType := c.checkRightHandSideOfForOf(expr.Parent)
753+
return c.checkDestructuringAssignment(expr, core.OrElse(iteratedType, c.errorType), CheckModeNormal, false)
754+
}
755+
// If this is from "for" initializer
756+
// for ({a } = elems[0];.....) { }
757+
if ast.IsBinaryExpression(expr.Parent) {
758+
iteratedType := c.getTypeOfExpression(expr.Parent.AsBinaryExpression().Right)
759+
return c.checkDestructuringAssignment(expr, core.OrElse(iteratedType, c.errorType), CheckModeNormal, false)
760+
}
761+
// If this is from nested object binding pattern
762+
// for ({ skills: { primary, secondary } } = multiRobot, i = 0; i < 1; i++) {
763+
if ast.IsPropertyAssignment(expr.Parent) {
764+
node := expr.Parent.Parent
765+
typeOfParentObjectLiteral := core.OrElse(c.getTypeOfAssignmentPattern(node), c.errorType)
766+
propertyIndex := slices.Index(node.AsObjectLiteralExpression().Properties.Nodes, expr.Parent)
767+
return c.checkObjectLiteralDestructuringPropertyAssignment(node, typeOfParentObjectLiteral, propertyIndex, nil, false)
768+
}
769+
// Array literal assignment - array destructuring pattern
770+
node := expr.Parent
771+
// [{ property1: p1, property2 }] = elems;
772+
typeOfArrayLiteral := core.OrElse(c.getTypeOfAssignmentPattern(node), c.errorType)
773+
elementType := core.OrElse(c.checkIteratedTypeOrElementType(IterationUseDestructuring, typeOfArrayLiteral, c.undefinedType, expr.Parent), c.errorType)
774+
return c.checkArrayLiteralDestructuringElementAssignment(node, typeOfArrayLiteral, slices.Index(node.AsArrayLiteralExpression().Elements.Nodes, expr), elementType, CheckModeNormal)
775+
}
776+
777+
func isArrayLiteralOrObjectLiteralDestructuringPattern(node *ast.Node) bool {
778+
if !(ast.IsArrayLiteralExpression(node) || ast.IsObjectLiteralExpression(node)) {
779+
return false
780+
}
781+
parent := node.Parent
782+
// [a,b,c] from:
783+
// [a, b, c] = someExpression;
784+
if ast.IsBinaryExpression(parent) && parent.AsBinaryExpression().Left == node && parent.AsBinaryExpression().OperatorToken.Kind == ast.KindEqualsToken {
785+
return true
786+
}
787+
// [a, b, c] from:
788+
// for([a, b, c] of expression)
789+
if ast.IsForOfStatement(parent) && parent.Initializer() == node {
790+
return true
791+
}
792+
// {x, a: {a, b, c} } = someExpression
793+
if ast.IsPropertyAssignment(parent) {
794+
return isArrayLiteralOrObjectLiteralDestructuringPattern(parent.Parent)
795+
}
796+
// [a, b, c] of
797+
// [x, [a, b, c] ] = someExpression
798+
return isArrayLiteralOrObjectLiteralDestructuringPattern(parent)
799+
}

internal/core/core.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,12 +351,6 @@ func Coalesce[T *U, U any](a T, b T) T {
351351
}
352352
}
353353

354-
// Returns the first element that is not `nil`; CoalesceList(a, b, c) is roughly analogous to `a ?? b ?? c` in JS, except that it
355-
// non-shortcutting, so it is advised to only use a constant or precomputed value for non-first values in the list
356-
func CoalesceList[T *U, U any](a ...T) T {
357-
return FirstNonNil(a, func(t T) T { return t })
358-
}
359-
360354
func ComputeLineStarts(text string) []TextPos {
361355
result := make([]TextPos, 0, strings.Count(text, "\n")+1)
362356
return slices.AppendSeq(result, ComputeLineStartsSeq(text))
@@ -625,3 +619,21 @@ func CopyMapInto[M1 ~map[K]V, M2 ~map[K]V, K comparable, V any](dst M1, src M2)
625619
maps.Copy(dst, src)
626620
return dst
627621
}
622+
623+
func Deduplicate[T comparable](slice []T) []T {
624+
if len(slice) > 1 {
625+
for i, value := range slice {
626+
if slices.Contains(slice[:i], value) {
627+
result := slices.Clone(slice[:i])
628+
for i++; i < len(slice); i++ {
629+
value = slice[i]
630+
if !slices.Contains(result, value) {
631+
result = append(result, value)
632+
}
633+
}
634+
return result
635+
}
636+
}
637+
}
638+
return slice
639+
}

0 commit comments

Comments
 (0)