Skip to content

Commit 1b9507a

Browse files
author
Benjamin Lichtman
committed
Wrap expressions returned from promises in awaits when appropriate
1 parent 0016fd7 commit 1b9507a

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/services/codefixes/convertToAsyncFunction.ts

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,6 @@ namespace ts.codefix {
435435
}
436436
return varDeclOrAssignment;
437437

438-
case SyntaxKind.FunctionDeclaration:
439438
case SyntaxKind.FunctionExpression:
440439
case SyntaxKind.ArrowFunction:
441440
// Arrow functions with block bodies { } will enter this control flow
@@ -457,11 +456,11 @@ namespace ts.codefix {
457456
}
458457

459458
return shouldReturn ? getSynthesizedDeepClones(createNodeArray(refactoredStmts)) :
460-
removeReturns(createNodeArray(refactoredStmts), prevArgName!.identifier, transformer.constIdentifiers, seenReturnStatement);
459+
removeReturns(createNodeArray(refactoredStmts), prevArgName!.identifier, transformer, seenReturnStatement);
461460
}
462461
else {
463-
const funcBody = (<ArrowFunction>func).body;
464-
const innerRetStmts = getReturnStatementsWithPromiseHandlers(createReturn(funcBody as Expression));
462+
const funcBody = cast((<ArrowFunction>func).body, isExpression);
463+
const innerRetStmts = getReturnStatementsWithPromiseHandlers(createReturn(funcBody));
465464
const innerCbBody = getInnerTransformationBody(transformer, innerRetStmts, prevArgName);
466465

467466
if (innerCbBody.length > 0) {
@@ -471,14 +470,16 @@ namespace ts.codefix {
471470
if (!shouldReturn) {
472471
const type = transformer.checker.getTypeAtLocation(func);
473472
const returnType = getLastCallSignature(type, transformer.checker)!.getReturnType();
474-
const varDeclOrAssignment = createVariableDeclarationOrAssignment(prevArgName, getSynthesizedDeepClone(funcBody) as Expression, transformer);
473+
const rightHandSide = getSynthesizedDeepClone(funcBody);
474+
const possiblyAwaitedRightHandSide = isPromiseReturningExpression(funcBody, transformer.checker) ? createAwait(rightHandSide) : rightHandSide;
475+
const varDeclOrAssignment = createVariableDeclarationOrAssignment(prevArgName, possiblyAwaitedRightHandSide, transformer);
475476
if (prevArgName) {
476477
prevArgName.types.push(returnType);
477478
}
478479
return varDeclOrAssignment;
479480
}
480481
else {
481-
return createNodeArray([createReturn(getSynthesizedDeepClone(funcBody) as Expression)]);
482+
return createNodeArray([createReturn(getSynthesizedDeepClone(funcBody))]);
482483
}
483484
}
484485
default:
@@ -495,13 +496,14 @@ namespace ts.codefix {
495496
}
496497

497498

498-
function removeReturns(stmts: NodeArray<Statement>, prevArgName: Identifier, constIdentifiers: Identifier[], seenReturnStatement: boolean): NodeArray<Statement> {
499+
function removeReturns(stmts: NodeArray<Statement>, prevArgName: Identifier, transformer: Transformer, seenReturnStatement: boolean): NodeArray<Statement> {
499500
const ret: Statement[] = [];
500501
for (const stmt of stmts) {
501502
if (isReturnStatement(stmt)) {
502503
if (stmt.expression) {
504+
const possiblyAwaitedExpression = isPromiseReturningExpression(stmt.expression, transformer.checker) ? createAwait(stmt.expression) : stmt.expression;
503505
ret.push(createVariableStatement(/*modifiers*/ undefined,
504-
(createVariableDeclarationList([createVariableDeclaration(prevArgName, /*type*/ undefined, stmt.expression)], getFlagOfIdentifier(prevArgName, constIdentifiers)))));
506+
(createVariableDeclarationList([createVariableDeclaration(prevArgName, /*type*/ undefined, possiblyAwaitedExpression)], getFlagOfIdentifier(prevArgName, transformer.constIdentifiers)))));
505507
}
506508
}
507509
else {
@@ -512,7 +514,7 @@ namespace ts.codefix {
512514
// if block has no return statement, need to define prevArgName as undefined to prevent undeclared variables
513515
if (!seenReturnStatement) {
514516
ret.push(createVariableStatement(/*modifiers*/ undefined,
515-
(createVariableDeclarationList([createVariableDeclaration(prevArgName, /*type*/ undefined, createIdentifier("undefined"))], getFlagOfIdentifier(prevArgName, constIdentifiers)))));
517+
(createVariableDeclarationList([createVariableDeclaration(prevArgName, /*type*/ undefined, createIdentifier("undefined"))], getFlagOfIdentifier(prevArgName, transformer.constIdentifiers)))));
516518
}
517519

518520
return createNodeArray(ret);

0 commit comments

Comments
 (0)