@@ -31,7 +31,7 @@ namespace ts.codefix {
31
31
errorCodes,
32
32
getCodeActions : context => {
33
33
const { sourceFile, errorCode, span, cancellationToken, program } = context ;
34
- const expression = getAwaitableExpression ( sourceFile , errorCode , span , cancellationToken , program ) ;
34
+ const expression = getFixableErrorSpanExpression ( sourceFile , errorCode , span , cancellationToken , program ) ;
35
35
if ( ! expression ) {
36
36
return ;
37
37
}
@@ -45,32 +45,40 @@ namespace ts.codefix {
45
45
getAllCodeActions : context => {
46
46
const { sourceFile, program, cancellationToken } = context ;
47
47
const checker = context . program . getTypeChecker ( ) ;
48
+ const fixedDeclarations = createMap < true > ( ) ;
48
49
return codeFixAll ( context , errorCodes , ( t , diagnostic ) => {
49
- const expression = getAwaitableExpression ( sourceFile , diagnostic . code , diagnostic , cancellationToken , program ) ;
50
+ const expression = getFixableErrorSpanExpression ( sourceFile , diagnostic . code , diagnostic , cancellationToken , program ) ;
50
51
if ( ! expression ) {
51
52
return ;
52
53
}
53
54
const trackChanges : ContextualTrackChangesFunction = cb => ( cb ( t ) , [ ] ) ;
54
- return getDeclarationSiteFix ( context , expression , diagnostic . code , checker , trackChanges )
55
- || getUseSiteFix ( context , expression , diagnostic . code , checker , trackChanges ) ;
55
+ return getDeclarationSiteFix ( context , expression , diagnostic . code , checker , trackChanges , fixedDeclarations )
56
+ || getUseSiteFix ( context , expression , diagnostic . code , checker , trackChanges , fixedDeclarations ) ;
56
57
} ) ;
57
58
} ,
58
59
} ) ;
59
60
60
- function getDeclarationSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction ) {
61
- const { sourceFile } = context ;
62
- const awaitableInitializer = findAwaitableInitializer ( expression , sourceFile , checker ) ;
63
- if ( awaitableInitializer ) {
64
- const initializerChanges = trackChanges ( t => makeChange ( t , errorCode , sourceFile , checker , awaitableInitializer ) ) ;
61
+ function getDeclarationSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction , fixedDeclarations ?: Map < true > ) {
62
+ const { sourceFile, program, cancellationToken } = context ;
63
+ const awaitableInitializers = findAwaitableInitializers ( expression , sourceFile , cancellationToken , program , checker ) ;
64
+ if ( awaitableInitializers ) {
65
+ const initializerChanges = trackChanges ( t => {
66
+ forEach ( awaitableInitializers . initializers , ( { expression } ) => makeChange ( t , errorCode , sourceFile , checker , expression , fixedDeclarations ) ) ;
67
+ if ( fixedDeclarations && awaitableInitializers . needsSecondPassForFixAll ) {
68
+ makeChange ( t , errorCode , sourceFile , checker , expression , fixedDeclarations ) ;
69
+ }
70
+ } ) ;
65
71
return createCodeFixActionNoFixId (
66
72
"addMissingAwaitToInitializer" ,
67
73
initializerChanges ,
68
- [ Diagnostics . Add_await_to_initializer_for_0 , expression . getText ( sourceFile ) ] ) ;
74
+ awaitableInitializers . initializers . length === 1
75
+ ? [ Diagnostics . Add_await_to_initializer_for_0 , awaitableInitializers . initializers [ 0 ] . declarationSymbol . name ]
76
+ : Diagnostics . Add_await_to_initializers ) ;
69
77
}
70
78
}
71
79
72
- function getUseSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction ) {
73
- const changes = trackChanges ( t => makeChange ( t , errorCode , context . sourceFile , checker , expression ) ) ;
80
+ function getUseSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction , fixedDeclarations ?: Map < true > ) {
81
+ const changes = trackChanges ( t => makeChange ( t , errorCode , context . sourceFile , checker , expression , fixedDeclarations ) ) ;
74
82
return createCodeFixAction ( fixId , changes , Diagnostics . Add_await , fixId , Diagnostics . Fix_all_expressions_possibly_missing_await ) ;
75
83
}
76
84
@@ -84,7 +92,7 @@ namespace ts.codefix {
84
92
some ( relatedInformation , related => related . code === Diagnostics . Did_you_forget_to_use_await . code ) ) ;
85
93
}
86
94
87
- function getAwaitableExpression ( sourceFile : SourceFile , errorCode : number , span : TextSpan , cancellationToken : CancellationToken , program : Program ) : Expression | undefined {
95
+ function getFixableErrorSpanExpression ( sourceFile : SourceFile , errorCode : number , span : TextSpan , cancellationToken : CancellationToken , program : Program ) : Expression | undefined {
88
96
const token = getTokenAtPosition ( sourceFile , span . start ) ;
89
97
// Checker has already done work to determine that await might be possible, and has attached
90
98
// related info to the node, so start by finding the expression that exactly matches up
@@ -103,38 +111,117 @@ namespace ts.codefix {
103
111
: undefined ;
104
112
}
105
113
106
- function findAwaitableInitializer ( expression : Node , sourceFile : SourceFile , checker : TypeChecker ) : Expression | undefined {
107
- if ( ! isIdentifier ( expression ) ) {
108
- return ;
109
- }
114
+ interface AwaitableInitializer {
115
+ expression : Expression ;
116
+ declarationSymbol : Symbol ;
117
+ }
110
118
111
- const symbol = checker . getSymbolAtLocation ( expression ) ;
112
- if ( ! symbol ) {
119
+ interface AwaitableInitializers {
120
+ initializers : readonly AwaitableInitializer [ ] ;
121
+ needsSecondPassForFixAll : boolean ;
122
+ }
123
+
124
+ function findAwaitableInitializers (
125
+ expression : Node ,
126
+ sourceFile : SourceFile ,
127
+ cancellationToken : CancellationToken ,
128
+ program : Program ,
129
+ checker : TypeChecker ,
130
+ ) : AwaitableInitializers | undefined {
131
+ const identifiers = getIdentifiersFromErrorSpanExpression ( expression , checker ) ;
132
+ if ( ! identifiers ) {
113
133
return ;
114
134
}
115
135
116
- const declaration = tryCast ( symbol . valueDeclaration , isVariableDeclaration ) ;
117
- const variableName = tryCast ( declaration && declaration . name , isIdentifier ) ;
118
- const variableStatement = getAncestor ( declaration , SyntaxKind . VariableStatement ) ;
119
- if ( ! declaration || ! variableStatement ||
120
- declaration . type ||
121
- ! declaration . initializer ||
122
- variableStatement . getSourceFile ( ) !== sourceFile ||
123
- hasModifier ( variableStatement , ModifierFlags . Export ) ||
124
- ! variableName ||
125
- ! isInsideAwaitableBody ( declaration . initializer ) ) {
126
- return ;
136
+ let isCompleteFix = identifiers . isCompleteFix ;
137
+ let initializers : AwaitableInitializer [ ] | undefined ;
138
+ for ( const identifier of identifiers . identifiers ) {
139
+ const symbol = checker . getSymbolAtLocation ( identifier ) ;
140
+ if ( ! symbol ) {
141
+ continue ;
142
+ }
143
+
144
+ const declaration = tryCast ( symbol . valueDeclaration , isVariableDeclaration ) ;
145
+ const variableName = declaration && tryCast ( declaration . name , isIdentifier ) ;
146
+ const variableStatement = getAncestor ( declaration , SyntaxKind . VariableStatement ) ;
147
+ if ( ! declaration || ! variableStatement ||
148
+ declaration . type ||
149
+ ! declaration . initializer ||
150
+ variableStatement . getSourceFile ( ) !== sourceFile ||
151
+ hasModifier ( variableStatement , ModifierFlags . Export ) ||
152
+ ! variableName ||
153
+ ! isInsideAwaitableBody ( declaration . initializer ) ) {
154
+ isCompleteFix = false ;
155
+ continue ;
156
+ }
157
+
158
+ const diagnostics = program . getSemanticDiagnostics ( sourceFile , cancellationToken ) ;
159
+ const isUsedElsewhere = FindAllReferences . Core . eachSymbolReferenceInFile ( variableName , checker , sourceFile , reference => {
160
+ return identifier !== reference && ! symbolReferenceIsAlsoMissingAwait ( reference , diagnostics , sourceFile , checker ) ;
161
+ } ) ;
162
+
163
+ if ( isUsedElsewhere ) {
164
+ isCompleteFix = false ;
165
+ continue ;
166
+ }
167
+
168
+ ( initializers || ( initializers = [ ] ) ) . push ( {
169
+ expression : declaration . initializer ,
170
+ declarationSymbol : symbol ,
171
+ } ) ;
127
172
}
173
+ return initializers && {
174
+ initializers,
175
+ needsSecondPassForFixAll : ! isCompleteFix ,
176
+ } ;
177
+ }
128
178
129
- const isUsedElsewhere = FindAllReferences . Core . eachSymbolReferenceInFile ( variableName , checker , sourceFile , identifier => {
130
- return identifier !== expression ;
131
- } ) ;
179
+ interface Identifiers {
180
+ identifiers : readonly Identifier [ ] ;
181
+ isCompleteFix : boolean ;
182
+ }
132
183
133
- if ( isUsedElsewhere ) {
134
- return ;
184
+ function getIdentifiersFromErrorSpanExpression ( expression : Node , checker : TypeChecker ) : Identifiers | undefined {
185
+ if ( isPropertyAccessExpression ( expression . parent ) && isIdentifier ( expression . parent . expression ) ) {
186
+ return { identifiers : [ expression . parent . expression ] , isCompleteFix : true } ;
187
+ }
188
+ if ( isIdentifier ( expression ) ) {
189
+ return { identifiers : [ expression ] , isCompleteFix : true } ;
190
+ }
191
+ if ( isBinaryExpression ( expression ) ) {
192
+ let sides : Identifier [ ] | undefined ;
193
+ let isCompleteFix = true ;
194
+ for ( const side of [ expression . left , expression . right ] ) {
195
+ const type = checker . getTypeAtLocation ( side ) ;
196
+ if ( checker . getPromisedTypeOfPromise ( type ) ) {
197
+ if ( ! isIdentifier ( side ) ) {
198
+ isCompleteFix = false ;
199
+ continue ;
200
+ }
201
+ ( sides || ( sides = [ ] ) ) . push ( side ) ;
202
+ }
203
+ }
204
+ return sides && { identifiers : sides , isCompleteFix } ;
135
205
}
206
+ }
207
+
208
+ function symbolReferenceIsAlsoMissingAwait ( reference : Identifier , diagnostics : readonly Diagnostic [ ] , sourceFile : SourceFile , checker : TypeChecker ) {
209
+ const errorNode = isPropertyAccessExpression ( reference . parent ) ? reference . parent . name :
210
+ isBinaryExpression ( reference . parent ) ? reference . parent :
211
+ reference ;
212
+ const diagnostic = find ( diagnostics , diagnostic =>
213
+ diagnostic . start === errorNode . getStart ( sourceFile ) &&
214
+ diagnostic . start + diagnostic . length ! === errorNode . getEnd ( ) ) ;
136
215
137
- return declaration . initializer ;
216
+ return diagnostic && contains ( errorCodes , diagnostic . code ) ||
217
+ // A Promise is usually not correct in a binary expression (it’s not valid
218
+ // in an arithmetic expression and an equality comparison seems unusual),
219
+ // but if the other side of the binary expression has an error, the side
220
+ // is typed `any` which will squash the error that would identify this
221
+ // Promise as an invalid operand. So if the whole binary expression is
222
+ // typed `any` as a result, there is a strong likelihood that this Promise
223
+ // is accidentally missing `await`.
224
+ checker . getTypeAtLocation ( errorNode ) . flags & TypeFlags . Any ;
138
225
}
139
226
140
227
function isInsideAwaitableBody ( node : Node ) {
@@ -147,26 +234,48 @@ namespace ts.codefix {
147
234
ancestor . parent . kind === SyntaxKind . MethodDeclaration ) ) ;
148
235
}
149
236
150
- function makeChange ( changeTracker : textChanges . ChangeTracker , errorCode : number , sourceFile : SourceFile , checker : TypeChecker , insertionSite : Expression ) {
237
+ function makeChange ( changeTracker : textChanges . ChangeTracker , errorCode : number , sourceFile : SourceFile , checker : TypeChecker , insertionSite : Expression , fixedDeclarations ?: Map < true > ) {
151
238
if ( isBinaryExpression ( insertionSite ) ) {
152
- const { left, right } = insertionSite ;
153
- const leftType = checker . getTypeAtLocation ( left ) ;
154
- const rightType = checker . getTypeAtLocation ( right ) ;
155
- const newLeft = checker . getPromisedTypeOfPromise ( leftType ) ? createAwait ( left ) : left ;
156
- const newRight = checker . getPromisedTypeOfPromise ( rightType ) ? createAwait ( right ) : right ;
157
- changeTracker . replaceNode ( sourceFile , left , newLeft ) ;
158
- changeTracker . replaceNode ( sourceFile , right , newRight ) ;
239
+ for ( const side of [ insertionSite . left , insertionSite . right ] ) {
240
+ if ( fixedDeclarations && isIdentifier ( side ) ) {
241
+ const symbol = checker . getSymbolAtLocation ( side ) ;
242
+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
243
+ continue ;
244
+ }
245
+ }
246
+ const type = checker . getTypeAtLocation ( side ) ;
247
+ const newNode = checker . getPromisedTypeOfPromise ( type ) ? createAwait ( side ) : side ;
248
+ changeTracker . replaceNode ( sourceFile , side , newNode ) ;
249
+ }
159
250
}
160
251
else if ( errorCode === propertyAccessCode && isPropertyAccessExpression ( insertionSite . parent ) ) {
252
+ if ( fixedDeclarations && isIdentifier ( insertionSite . parent . expression ) ) {
253
+ const symbol = checker . getSymbolAtLocation ( insertionSite . parent . expression ) ;
254
+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
255
+ return ;
256
+ }
257
+ }
161
258
changeTracker . replaceNode (
162
259
sourceFile ,
163
260
insertionSite . parent . expression ,
164
261
createParen ( createAwait ( insertionSite . parent . expression ) ) ) ;
165
262
}
166
263
else if ( contains ( callableConstructableErrorCodes , errorCode ) && isCallOrNewExpression ( insertionSite . parent ) ) {
264
+ if ( fixedDeclarations && isIdentifier ( insertionSite ) ) {
265
+ const symbol = checker . getSymbolAtLocation ( insertionSite ) ;
266
+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
267
+ return ;
268
+ }
269
+ }
167
270
changeTracker . replaceNode ( sourceFile , insertionSite , createParen ( createAwait ( insertionSite ) ) ) ;
168
271
}
169
272
else {
273
+ if ( fixedDeclarations && isVariableDeclaration ( insertionSite . parent ) && isIdentifier ( insertionSite . parent . name ) ) {
274
+ const symbol = checker . getSymbolAtLocation ( insertionSite . parent . name ) ;
275
+ if ( symbol && ! addToSeen ( fixedDeclarations , getSymbolId ( symbol ) ) ) {
276
+ return ;
277
+ }
278
+ }
170
279
changeTracker . replaceNode ( sourceFile , insertionSite , createAwait ( insertionSite ) ) ;
171
280
}
172
281
}
0 commit comments