@@ -435,7 +435,6 @@ namespace ts.codefix {
435
435
}
436
436
return varDeclOrAssignment ;
437
437
438
- case SyntaxKind . FunctionDeclaration :
439
438
case SyntaxKind . FunctionExpression :
440
439
case SyntaxKind . ArrowFunction :
441
440
// Arrow functions with block bodies { } will enter this control flow
@@ -457,11 +456,11 @@ namespace ts.codefix {
457
456
}
458
457
459
458
return shouldReturn ? getSynthesizedDeepClones ( createNodeArray ( refactoredStmts ) ) :
460
- removeReturns ( createNodeArray ( refactoredStmts ) , prevArgName ! . identifier , transformer . constIdentifiers , seenReturnStatement ) ;
459
+ removeReturns ( createNodeArray ( refactoredStmts ) , prevArgName ! . identifier , transformer , seenReturnStatement ) ;
461
460
}
462
461
else {
463
- const funcBody = ( < ArrowFunction > func ) . body ;
464
- const innerRetStmts = getReturnStatementsWithPromiseHandlers ( createReturn ( funcBody as Expression ) ) ;
462
+ const funcBody = cast ( ( < ArrowFunction > func ) . body , isExpression ) ;
463
+ const innerRetStmts = getReturnStatementsWithPromiseHandlers ( createReturn ( funcBody ) ) ;
465
464
const innerCbBody = getInnerTransformationBody ( transformer , innerRetStmts , prevArgName ) ;
466
465
467
466
if ( innerCbBody . length > 0 ) {
@@ -471,14 +470,16 @@ namespace ts.codefix {
471
470
if ( ! shouldReturn ) {
472
471
const type = transformer . checker . getTypeAtLocation ( func ) ;
473
472
const returnType = getLastCallSignature ( type , transformer . checker ) ! . getReturnType ( ) ;
474
- const varDeclOrAssignment = createVariableDeclarationOrAssignment ( prevArgName , getSynthesizedDeepClone ( funcBody ) as Expression , transformer ) ;
473
+ const rightHandSide = getSynthesizedDeepClone ( funcBody ) ;
474
+ const possiblyAwaitedRightHandSide = isPromiseReturningExpression ( funcBody , transformer . checker ) ? createAwait ( rightHandSide ) : rightHandSide ;
475
+ const varDeclOrAssignment = createVariableDeclarationOrAssignment ( prevArgName , possiblyAwaitedRightHandSide , transformer ) ;
475
476
if ( prevArgName ) {
476
477
prevArgName . types . push ( returnType ) ;
477
478
}
478
479
return varDeclOrAssignment ;
479
480
}
480
481
else {
481
- return createNodeArray ( [ createReturn ( getSynthesizedDeepClone ( funcBody ) as Expression ) ] ) ;
482
+ return createNodeArray ( [ createReturn ( getSynthesizedDeepClone ( funcBody ) ) ] ) ;
482
483
}
483
484
}
484
485
default :
@@ -495,13 +496,14 @@ namespace ts.codefix {
495
496
}
496
497
497
498
498
- function removeReturns ( stmts : NodeArray < Statement > , prevArgName : Identifier , constIdentifiers : Identifier [ ] , seenReturnStatement : boolean ) : NodeArray < Statement > {
499
+ function removeReturns ( stmts : NodeArray < Statement > , prevArgName : Identifier , transformer : Transformer , seenReturnStatement : boolean ) : NodeArray < Statement > {
499
500
const ret : Statement [ ] = [ ] ;
500
501
for ( const stmt of stmts ) {
501
502
if ( isReturnStatement ( stmt ) ) {
502
503
if ( stmt . expression ) {
504
+ const possiblyAwaitedExpression = isPromiseReturningExpression ( stmt . expression , transformer . checker ) ? createAwait ( stmt . expression ) : stmt . expression ;
503
505
ret . push ( createVariableStatement ( /*modifiers*/ undefined ,
504
- ( createVariableDeclarationList ( [ createVariableDeclaration ( prevArgName , /*type*/ undefined , stmt . expression ) ] , getFlagOfIdentifier ( prevArgName , constIdentifiers ) ) ) ) ) ;
506
+ ( createVariableDeclarationList ( [ createVariableDeclaration ( prevArgName , /*type*/ undefined , possiblyAwaitedExpression ) ] , getFlagOfIdentifier ( prevArgName , transformer . constIdentifiers ) ) ) ) ) ;
505
507
}
506
508
}
507
509
else {
@@ -512,7 +514,7 @@ namespace ts.codefix {
512
514
// if block has no return statement, need to define prevArgName as undefined to prevent undeclared variables
513
515
if ( ! seenReturnStatement ) {
514
516
ret . push ( createVariableStatement ( /*modifiers*/ undefined ,
515
- ( createVariableDeclarationList ( [ createVariableDeclaration ( prevArgName , /*type*/ undefined , createIdentifier ( "undefined" ) ) ] , getFlagOfIdentifier ( prevArgName , constIdentifiers ) ) ) ) ) ;
517
+ ( createVariableDeclarationList ( [ createVariableDeclaration ( prevArgName , /*type*/ undefined , createIdentifier ( "undefined" ) ) ] , getFlagOfIdentifier ( prevArgName , transformer . constIdentifiers ) ) ) ) ) ;
516
518
}
517
519
518
520
return createNodeArray ( ret ) ;
0 commit comments