|
9 | 9 | using System.Threading; |
10 | 10 | using System.Threading.Tasks; |
11 | 11 | using Microsoft.CodeAnalysis.LanguageService; |
| 12 | +using Microsoft.CodeAnalysis.PooledObjects; |
12 | 13 | using Microsoft.CodeAnalysis.Shared.Extensions; |
13 | 14 | using Microsoft.CodeAnalysis.Text; |
14 | 15 | using Roslyn.Utilities; |
@@ -45,7 +46,7 @@ internal abstract class SelectionResult( |
45 | 46 | /// </summary> |
46 | 47 | private ControlFlowAnalysis? _statementControlFlowAnalysis; |
47 | 48 |
|
48 | | - protected abstract bool UnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken); |
| 49 | + // protected abstract bool UnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken); |
49 | 50 |
|
50 | 51 | public abstract TExecutableStatementSyntax GetFirstStatementUnderContainer(); |
51 | 52 | public abstract TExecutableStatementSyntax GetLastStatementUnderContainer(); |
@@ -119,31 +120,36 @@ public TExecutableStatementSyntax GetLastStatement() |
119 | 120 | return token.GetRequiredAncestor<TExecutableStatementSyntax>(); |
120 | 121 | } |
121 | 122 |
|
122 | | - public bool CreateAsyncMethod() |
| 123 | + public bool ContainsAwaitExpression() |
123 | 124 | { |
124 | 125 | _createAsyncMethod ??= CreateAsyncMethodWorker(); |
125 | 126 | return _createAsyncMethod.Value; |
126 | 127 |
|
127 | 128 | bool CreateAsyncMethodWorker() |
128 | 129 | { |
129 | | - var firstToken = GetFirstTokenInSelection(); |
130 | | - var lastToken = GetLastTokenInSelection(); |
131 | | - var syntaxFacts = SemanticDocument.GetRequiredLanguageService<ISyntaxFactsService>(); |
| 130 | + var firstToken = this.GetFirstTokenInSelection(); |
| 131 | + var lastToken = this.GetLastTokenInSelection(); |
| 132 | + var span = TextSpan.FromBounds(firstToken.SpanStart, lastToken.Span.End); |
132 | 133 |
|
133 | | - for (var currentToken = firstToken; |
134 | | - currentToken.Span.End < lastToken.SpanStart; |
135 | | - currentToken = currentToken.GetNextToken()) |
| 134 | + using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var stack); |
| 135 | + stack.Push(this.GetContainingScope()); |
| 136 | + |
| 137 | + var syntaxFacts = this.SemanticDocument.GetRequiredLanguageService<ISyntaxFactsService>(); |
| 138 | + |
| 139 | + while (stack.TryPop(out var current)) |
136 | 140 | { |
137 | | - // [| |
138 | | - // async () => await .... |
139 | | - // |] |
140 | | - // |
141 | | - // for the case above, even if the selection contains "await", it doesn't belong to the enclosing block |
142 | | - // which extract method is applied to |
143 | | - if (syntaxFacts.IsAwaitKeyword(currentToken) |
144 | | - && !UnderAnonymousOrLocalMethod(currentToken, firstToken, lastToken)) |
145 | | - { |
| 141 | + // Don't dive into lambdas and local functions. They reset the async/await context. |
| 142 | + if (syntaxFacts.IsAnonymousOrLocalFunction(current)) |
| 143 | + continue; |
| 144 | + |
| 145 | + if (syntaxFacts.IsAwaitExpression(current)) |
146 | 146 | return true; |
| 147 | + |
| 148 | + // Only dive into child nodes within the span being extracted. |
| 149 | + foreach (var childNode in current.ChildNodes()) |
| 150 | + { |
| 151 | + if (childNode.Span.OverlapsWith(span)) |
| 152 | + stack.Push(childNode); |
147 | 153 | } |
148 | 154 | } |
149 | 155 |
|
|
0 commit comments