Skip to content

Commit 4951fd3

Browse files
IN progress
1 parent db1fa67 commit 4951fd3

File tree

8 files changed

+63
-79
lines changed

8 files changed

+63
-79
lines changed

src/Features/CSharp/Portable/ExtractMethod/CSharpMethodExtractor.CSharpCodeGenerator.cs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ private bool ShouldPutUnsafeModifier()
209209
private DeclarationModifiers CreateMethodModifiers()
210210
{
211211
var isUnsafe = ShouldPutUnsafeModifier();
212-
var isAsync = this.SelectionResult.CreateAsyncMethod();
212+
var isAsync = this.SelectionResult.ContainsAwaitExpression();
213213
var isStatic = !AnalyzerResult.UseInstanceMember;
214214
var isReadOnly = AnalyzerResult.ShouldBeReadOnly;
215215

@@ -604,21 +604,20 @@ protected override ExpressionSyntax CreateCallSignature()
604604
}
605605

606606
var invocation = (ExpressionSyntax)InvocationExpression(methodExpression, ArgumentList([.. arguments]));
607+
608+
// If we're extracting any code that contained an 'await' then we'll have to await the new method we're
609+
// calling as well. If we also see any use of .ConfigureAwait(false) in the extracted code, keep that
610+
// pattern on the await expression we produce.
607611
if (this.SelectionResult.ContainsAwaitExpression())
608612
{
609613
if (this.SelectionResult.ContainsConfigureAwaitFalse())
610614
{
611-
if (this.GetFinalReturnType()
612-
.GetMembers(nameof(Task.ConfigureAwait))
613-
.Any(static x => x is IMethodSymbol { Parameters: [{ Type.SpecialType: SpecialType.System_Boolean }] }))
614-
{
615-
invocation = InvocationExpression(
616-
MemberAccessExpression(
617-
SyntaxKind.SimpleMemberAccessExpression,
618-
invocation,
619-
IdentifierName(nameof(Task.ConfigureAwait))),
620-
ArgumentList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))]));
621-
}
615+
invocation = InvocationExpression(
616+
MemberAccessExpression(
617+
SyntaxKind.SimpleMemberAccessExpression,
618+
invocation,
619+
IdentifierName(nameof(Task.ConfigureAwait))),
620+
ArgumentList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))]));
622621
}
623622

624623
invocation = AwaitExpression(invocation);

src/Features/CSharp/Portable/ExtractMethod/CSharpSelectionResult.StatementResult.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,13 @@ private sealed class StatementResult(
2727
: CSharpSelectionResult(document, selectionType, finalSpan)
2828
{
2929
public override bool ContainingScopeHasAsyncKeyword()
30-
{
31-
var node = GetContainingScope();
32-
33-
return node switch
30+
=> GetContainingScope() switch
3431
{
3532
MethodDeclarationSyntax method => method.Modifiers.Any(SyntaxKind.AsyncKeyword),
3633
LocalFunctionStatementSyntax localFunction => localFunction.Modifiers.Any(SyntaxKind.AsyncKeyword),
3734
AnonymousFunctionExpressionSyntax anonymousFunction => anonymousFunction.AsyncKeyword != default,
3835
_ => false,
3936
};
40-
}
4137

4238
public override SyntaxNode GetContainingScope()
4339
{

src/Features/CSharp/Portable/ExtractMethod/CSharpSelectionResult.cs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
using System.Threading.Tasks;
1111
using Microsoft.CodeAnalysis;
1212
using Microsoft.CodeAnalysis.CSharp.Extensions;
13-
using Microsoft.CodeAnalysis.CSharp.LanguageService;
1413
using Microsoft.CodeAnalysis.CSharp.Syntax;
1514
using Microsoft.CodeAnalysis.ExtractMethod;
16-
using Microsoft.CodeAnalysis.LanguageService;
1715
using Microsoft.CodeAnalysis.Shared.Extensions;
1816
using Microsoft.CodeAnalysis.Text;
1917
using Roslyn.Utilities;
@@ -68,27 +66,6 @@ protected override SyntaxNode GetNodeForDataFlowAnalysis()
6866
: node;
6967
}
7068

71-
protected override bool UnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken)
72-
=> IsUnderAnonymousOrLocalMethod(token, firstToken, lastToken);
73-
74-
public static bool IsUnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken)
75-
{
76-
for (var current = token.Parent; current != null; current = current.Parent)
77-
{
78-
if (current is MemberDeclarationSyntax)
79-
return false;
80-
81-
if (current is AnonymousFunctionExpressionSyntax or LocalFunctionStatementSyntax)
82-
{
83-
// make sure the selection contains the lambda
84-
return firstToken.SpanStart <= current.GetFirstToken().SpanStart &&
85-
current.GetLastToken().Span.End <= lastToken.Span.End;
86-
}
87-
}
88-
89-
return false;
90-
}
91-
9269
public override StatementSyntax GetFirstStatementUnderContainer()
9370
{
9471
Contract.ThrowIfTrue(IsExtractMethodOnExpression);

src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected virtual bool IsReadOutside(ISymbol symbol, HashSet<ISymbol> readOutsid
7171
public AnalyzerResult Analyze()
7272
{
7373
// do data flow analysis
74-
var model = this.SemanticDocument.SemanticModel;
74+
var model = this.SemanticModel;
7575
var dataFlowAnalysisData = this.SelectionResult.GetDataFlowAnalysis();
7676

7777
// build symbol map for the identifiers used inside of the selection
@@ -158,7 +158,7 @@ public AnalyzerResult Analyze()
158158
// check whether current selection contains return statement
159159
var (returnType, returnsByRef) = SelectionResult.GetReturnTypeInfo(this.CancellationToken);
160160

161-
return (allVariableInfos, returnType, returnsByRef);
161+
return (allVariableInfos, UnwrapTaskIfNeeded(returnType), returnsByRef);
162162
}
163163
else
164164
{
@@ -171,9 +171,34 @@ public AnalyzerResult Analyze()
171171
return (finalOrderedVariableInfos, returnType, returnsByRef: false);
172172
}
173173

174+
ITypeSymbol UnwrapTaskIfNeeded(ITypeSymbol returnType)
175+
{
176+
if (this.SelectionResult.ContainingScopeHasAsyncKeyword())
177+
{
178+
// We compute the desired return type for the extract method from our own return type. But for
179+
// the purposes of manipulating the return type, we need to get to the underlying type if this
180+
// was wrapped in a Task in an explicitly 'async' method. In other words, if we're in an `async
181+
// Task<int>` method, then we want the extract method to return `int`. Note: we will possibly
182+
// then wrap that as `Task<int>` again if we see that we extracted out any await-expressions.
183+
184+
var compilation = this.SemanticModel.Compilation;
185+
var knownTaskTypes = new KnownTaskTypes(compilation);
186+
187+
// Map from `Task/ValueTask` to `void`
188+
if (returnType.Equals(knownTaskTypes.TaskType) || returnType.Equals(knownTaskTypes.ValueTaskType))
189+
return compilation.GetSpecialType(SpecialType.System_Void);
190+
191+
// Map from `Task<T>/ValueTask<T>` to `T`
192+
if (returnType.OriginalDefinition.Equals(knownTaskTypes.TaskOfTType) || returnType.OriginalDefinition.Equals(knownTaskTypes.ValueTaskOfTType))
193+
return returnType.GetTypeArguments().Single();
194+
}
195+
196+
return returnType;
197+
}
198+
174199
ITypeSymbol GetReturnType(ImmutableArray<VariableInfo> variablesToUseAsReturnValue)
175200
{
176-
var compilation = this.SemanticDocument.SemanticModel.Compilation;
201+
var compilation = this.SemanticModel.Compilation;
177202

178203
if (variablesToUseAsReturnValue.IsEmpty)
179204
return compilation.GetSpecialType(SpecialType.System_Void);

src/Features/Core/Portable/ExtractMethod/MethodExtractor.CodeGenerator.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,22 @@ protected TExecutableStatementSyntax GetStatementContainingInvocationToExtracted
376376
public ITypeSymbol GetFinalReturnType()
377377
{
378378
return _finalReturnType ??= ComputeFinalReturnType();
379-
}
380379

381-
private ITypeSymbol ComputeFinalReturnType()
382-
{
383-
throw new NotImplementedException();
380+
ITypeSymbol ComputeFinalReturnType()
381+
{
382+
var coreType = this.AnalyzerResult.CoreReturnType;
383+
if (this.SelectionResult.ContainsAwaitExpression())
384+
{
385+
// If we're awaiting, then we're going to be returning a task of some sort. Convert `void` to
386+
// `Task` and any other T to `Task<T>`.
387+
var compilation = this.SemanticDocument.SemanticModel.Compilation;
388+
return coreType.SpecialType == SpecialType.System_Void
389+
? compilation.TaskType()
390+
: compilation.TaskOfTType().Construct(coreType);
391+
}
392+
393+
return coreType;
394+
}
384395
}
385396
}
386397
}

src/Features/Core/Portable/ExtractMethod/SelectionResult.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ private bool CheckNodesInSelection(Func<ISyntaxFacts, SyntaxNode, bool> predicat
144144
if (syntaxFacts.IsAnonymousOrLocalFunction(current))
145145
continue;
146146

147-
if (predicate(syntaxFacts, current))
147+
if (current.Span.OverlapsWith(span) && predicate(syntaxFacts, current))
148148
return true;
149149

150150
// Only dive into child nodes within the span being extracted.

src/Features/VisualBasic/Portable/ExtractMethod/VisualBasicMethodExtractor.VisualBasicCodeGenerator.vb

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.ExtractMethod
173173
isShared = True
174174
End If
175175

176-
Dim isAsync = Me.SelectionResult.CreateAsyncMethod()
176+
Dim isAsync = Me.SelectionResult.ContainsAwaitExpression()
177177

178178
Return New DeclarationModifiers(isStatic:=isShared, isAsync:=isAsync)
179179
End Function
@@ -361,14 +361,12 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.ExtractMethod
361361
Dim invocation = SyntaxFactory.InvocationExpression(
362362
methodExpression, SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments)))
363363

364-
If Me.SelectionResult.CreateAsyncMethod() Then
365-
If Me.SelectionResult.ShouldCallConfigureAwaitFalse() Then
366-
If Me.GetFinalReturnType().
367-
GetMembers(NameOf(Task.ConfigureAwait)).
368-
OfType(Of IMethodSymbol).
369-
Any(Function(method) method.Parameters.Length = 1 AndAlso method.Parameters(0).Type.SpecialType = SpecialType.System_Boolean) Then
370-
371-
invocation = SyntaxFactory.InvocationExpression(
364+
' If we're extracting any code that contained an 'await' then we'll have to await the new method
365+
' we're calling as well. If we also see any use of .ConfigureAwait(false) in the extracted code,
366+
' keep that pattern on the await expression we produce.
367+
If Me.SelectionResult.ContainsAwaitExpression() Then
368+
If Me.SelectionResult.ContainsConfigureAwaitFalse() Then
369+
invocation = SyntaxFactory.InvocationExpression(
372370
SyntaxFactory.MemberAccessExpression(
373371
SyntaxKind.SimpleMemberAccessExpression,
374372
invocation,
@@ -379,7 +377,6 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.ExtractMethod
379377
SyntaxFactory.LiteralExpression(
380378
SyntaxKind.FalseLiteralExpression,
381379
SyntaxFactory.Token(SyntaxKind.FalseKeyword))))))
382-
End If
383380
End If
384381

385382
Return SyntaxFactory.AwaitExpression(invocation)

src/Features/VisualBasic/Portable/ExtractMethod/VisualBasicSelectionResult.vb

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,6 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.ExtractMethod
4343
MyBase.New(document, selectionType, finalSpan)
4444
End Sub
4545

46-
Protected Overrides Function UnderAnonymousOrLocalMethod(token As SyntaxToken, firstToken As SyntaxToken, lastToken As SyntaxToken) As Boolean
47-
Dim current = token.Parent
48-
49-
While current IsNot Nothing
50-
If TypeOf current Is DeclarationStatementSyntax OrElse
51-
TypeOf current Is LambdaExpressionSyntax Then
52-
Exit While
53-
End If
54-
55-
current = current.Parent
56-
End While
57-
58-
If current Is Nothing OrElse TypeOf current Is DeclarationStatementSyntax Then
59-
Return False
60-
End If
61-
62-
' make sure selection contains the lambda
63-
Return firstToken.SpanStart <= current.GetFirstToken().SpanStart AndAlso
64-
current.GetLastToken().Span.End <= lastToken.Span.End
65-
End Function
66-
6746
Public Overrides Function GetOutermostCallSiteContainerToProcess(cancellationToken As CancellationToken) As SyntaxNode
6847
If Me.IsExtractMethodOnExpression Then
6948
Dim container = Me.InnermostStatementContainer()

0 commit comments

Comments
 (0)