diff --git a/internal/ast/utilities.go b/internal/ast/utilities.go index 53e17552be..4968528f9d 100644 --- a/internal/ast/utilities.go +++ b/internal/ast/utilities.go @@ -3865,3 +3865,7 @@ func IsJSDocNameReferenceContext(node *Node) bool { func IsImportOrImportEqualsDeclaration(node *Node) bool { return IsImportDeclaration(node) || IsImportEqualsDeclaration(node) } + +func GetContainingFunction(node *Node) *Node { + return FindAncestor(node.Parent, IsFunctionLike) +} diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 5deba68611..7cbeeb5d34 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -2464,7 +2464,7 @@ func (c *Checker) checkParameter(node *ast.Node) { // or if its FunctionBody is strict code(11.1.5). c.checkGrammarModifiers(node) c.checkVariableLikeDeclaration(node) - fn := getContainingFunction(node) + fn := ast.GetContainingFunction(node) var paramName string if node.Name() != nil && ast.IsIdentifier(node.Name()) { paramName = node.Name().Text() @@ -5522,7 +5522,7 @@ func (c *Checker) checkVariableLikeDeclaration(node *ast.Node) { } if ast.IsBindingElement(node) { propName := node.PropertyName() - if propName != nil && ast.IsIdentifier(node.Name()) && ast.IsPartOfParameterDeclaration(node) && ast.NodeIsMissing(getContainingFunction(node).Body()) { + if propName != nil && ast.IsIdentifier(node.Name()) && ast.IsPartOfParameterDeclaration(node) && ast.NodeIsMissing(ast.GetContainingFunction(node).Body()) { // type F = ({a: string}) => void; // ^^^^^^ // variable renaming in function type notation is confusing, @@ -5557,7 +5557,7 @@ func (c *Checker) checkVariableLikeDeclaration(node *ast.Node) { c.checkSourceElements(name.AsBindingPattern().Elements.Nodes) } // For a parameter declaration with an initializer, error and exit if the containing function doesn't have a body - if initializer != nil && ast.IsPartOfParameterDeclaration(node) && ast.NodeIsMissing(getContainingFunction(node).Body()) { + if initializer != nil && ast.IsPartOfParameterDeclaration(node) && ast.NodeIsMissing(ast.GetContainingFunction(node).Body()) { c.error(node, diagnostics.A_parameter_initializer_is_only_allowed_in_a_function_or_constructor_implementation) return } @@ -10508,7 +10508,7 @@ func (c *Checker) checkSpreadExpression(node *ast.Node, checkMode CheckMode) *Ty func (c *Checker) checkYieldExpression(node *ast.Node) *Type { c.checkGrammarYieldExpression(node) - fn := getContainingFunction(node) + fn := ast.GetContainingFunction(node) if fn == nil { return c.anyType } @@ -22814,7 +22814,7 @@ func (c *Checker) getTypeAliasInstantiation(symbol *ast.Symbol, typeArguments [] func isLocalTypeAlias(symbol *ast.Symbol) bool { declaration := core.Find(symbol.Declarations, isTypeAlias) - return declaration != nil && getContainingFunction(declaration) != nil + return declaration != nil && ast.GetContainingFunction(declaration) != nil } func (c *Checker) getDeclaredTypeOfSymbol(symbol *ast.Symbol) *Type { @@ -28318,7 +28318,7 @@ func (c *Checker) getContextualTypeForStaticPropertyDeclaration(declaration *ast } func (c *Checker) getContextualTypeForReturnExpression(node *ast.Node, contextFlags ContextFlags) *Type { - fn := getContainingFunction(node) + fn := ast.GetContainingFunction(node) if fn != nil { contextualReturnType := c.getContextualReturnType(fn, contextFlags) if contextualReturnType != nil { @@ -28413,7 +28413,7 @@ func (c *Checker) getContextualSignatureForFunctionLikeDeclaration(node *ast.Nod } func (c *Checker) getContextualTypeForYieldOperand(node *ast.Node, contextFlags ContextFlags) *Type { - fn := getContainingFunction(node) + fn := ast.GetContainingFunction(node) if fn != nil { functionFlags := getFunctionFlags(fn) contextualReturnType := c.getContextualReturnType(fn, contextFlags) diff --git a/internal/checker/grammarchecks.go b/internal/checker/grammarchecks.go index 75943613bf..1b798ff413 100644 --- a/internal/checker/grammarchecks.go +++ b/internal/checker/grammarchecks.go @@ -1263,7 +1263,7 @@ func (c *Checker) checkGrammarForInOrForOfStatement(forInOrOfStatement *ast.ForI // use of 'for-await-of' in non-async function if !c.hasParseDiagnostics(sourceFile) { diagnostic := createDiagnosticForNode(forInOrOfStatement.AwaitModifier, diagnostics.X_for_await_loops_are_only_allowed_within_async_functions_and_at_the_top_levels_of_modules) - containingFunc := getContainingFunction(forInOrOfStatement.AsNode()) + containingFunc := ast.GetContainingFunction(forInOrOfStatement.AsNode()) if containingFunc != nil && containingFunc.Kind != ast.KindConstructor { debug.Assert((getFunctionFlags(containingFunc)&FunctionFlagsAsync) == 0, "Enclosing function should never be an async function.") if hasAsyncModifier(containingFunc) { diff --git a/internal/checker/utilities.go b/internal/checker/utilities.go index e131201893..3637b8d421 100644 --- a/internal/checker/utilities.go +++ b/internal/checker/utilities.go @@ -911,10 +911,6 @@ func (s *orderedSet[T]) add(value T) { s.values = append(s.values, value) } -func getContainingFunction(node *ast.Node) *ast.Node { - return ast.FindAncestor(node.Parent, ast.IsFunctionLike) -} - func getContainingFunctionOrClassStaticBlock(node *ast.Node) *ast.Node { return ast.FindAncestor(node.Parent, ast.IsFunctionLikeOrClassStaticBlockDeclaration) } diff --git a/internal/ls/documenthighlights.go b/internal/ls/documenthighlights.go index 8469c0935d..f0ed4ce0e0 100644 --- a/internal/ls/documenthighlights.go +++ b/internal/ls/documenthighlights.go @@ -550,14 +550,14 @@ func getLoopBreakContinueOccurrences(node *ast.Node, sourceFile *ast.SourceFile) } func getAsyncAndAwaitOccurrences(node *ast.Node, sourceFile *ast.SourceFile) []*ast.Node { - parent := ast.FindAncestor(node.Parent, ast.IsFunctionLike) - if parent == nil { + fun := ast.GetContainingFunction(node) + if fun == nil { return nil } - parentFunc := parent.AsFunctionDeclaration() + var keywords []*ast.Node - modifiers := parentFunc.Modifiers() + modifiers := fun.Modifiers() if modifiers != nil { for _, modifier := range modifiers.Nodes { if modifier.Kind == ast.KindAsyncKeyword { @@ -566,7 +566,7 @@ func getAsyncAndAwaitOccurrences(node *ast.Node, sourceFile *ast.SourceFile) []* } } - parentFunc.ForEachChild(func(child *ast.Node) bool { + fun.ForEachChild(func(child *ast.Node) bool { traverseWithoutCrossingFunction(child, sourceFile, func(child *ast.Node) { if ast.IsAwaitExpression(child) { token := lsutil.GetFirstToken(child, sourceFile)