@@ -25471,8 +25471,8 @@ namespace ts {
2547125471
2547225472 if (functionFlags & FunctionFlags.Async) { // Async function or AsyncGenerator function
2547325473 // Get the awaited type without the `Awaited<T>` alias
25474- const contextualAwaitedType = mapType(contextualReturnType, getAwaitedType );
25475- return contextualAwaitedType && getUnionType([unwrapAwaitedType( contextualAwaitedType) , createPromiseLikeType(contextualAwaitedType)]);
25474+ const contextualAwaitedType = mapType(contextualReturnType, getAwaitedTypeNoAlias );
25475+ return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
2547625476 }
2547725477
2547825478 return contextualReturnType; // Regular function or Generator function
@@ -25484,8 +25484,8 @@ namespace ts {
2548425484 function getContextualTypeForAwaitOperand(node: AwaitExpression, contextFlags?: ContextFlags): Type | undefined {
2548525485 const contextualType = getContextualType(node, contextFlags);
2548625486 if (contextualType) {
25487- const contextualAwaitedType = getAwaitedType (contextualType);
25488- return contextualAwaitedType && getUnionType([unwrapAwaitedType( contextualAwaitedType) , createPromiseLikeType(contextualAwaitedType)]);
25487+ const contextualAwaitedType = getAwaitedTypeNoAlias (contextualType);
25488+ return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
2548925489 }
2549025490 return undefined;
2549125491 }
@@ -31158,7 +31158,8 @@ namespace ts {
3115831158 const globalPromiseType = getGlobalPromiseType(/*reportErrors*/ true);
3115931159 if (globalPromiseType !== emptyGenericType) {
3116031160 // if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31161- promisedType = getAwaitedType(promisedType) || unknownType;
31161+ // Unwrap an `Awaited<T>` to `T` to improve inference.
31162+ promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
3116231163 return createTypeReference(globalPromiseType, [promisedType]);
3116331164 }
3116431165
@@ -31170,7 +31171,8 @@ namespace ts {
3117031171 const globalPromiseLikeType = getGlobalPromiseLikeType(/*reportErrors*/ true);
3117131172 if (globalPromiseLikeType !== emptyGenericType) {
3117231173 // if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
31173- promisedType = getAwaitedType(promisedType) || unknownType;
31174+ // Unwrap an `Awaited<T>` to `T` to improve inference.
31175+ promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
3117431176 return createTypeReference(globalPromiseLikeType, [promisedType]);
3117531177 }
3117631178
@@ -31227,7 +31229,7 @@ namespace ts {
3122731229 // Promise/A+ compatible implementation will always assimilate any foreign promise, so the
3122831230 // return type of the body should be unwrapped to its awaited type, which we will wrap in
3122931231 // the native Promise<T> type later in this function.
31230- returnType = checkAwaitedType(returnType, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31232+ returnType = unwrapAwaitedType( checkAwaitedType(returnType, /*withAlias*/ false, /* errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member) );
3123131233 }
3123231234 }
3123331235 else if (isGenerator) { // Generator or AsyncGenerator function
@@ -31460,7 +31462,7 @@ namespace ts {
3146031462 // Promise/A+ compatible implementation will always assimilate any foreign promise, so the
3146131463 // return type of the body should be unwrapped to its awaited type, which should be wrapped in
3146231464 // the native Promise<T> type by the caller.
31463- type = checkAwaitedType(type, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31465+ type = unwrapAwaitedType( checkAwaitedType(type, /*withAlias*/ false, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member) );
3146431466 }
3146531467 if (type.flags & TypeFlags.Never) {
3146631468 hasReturnOfTypeNever = true;
@@ -31662,7 +31664,7 @@ namespace ts {
3166231664 const returnOrPromisedType = returnType && unwrapReturnType(returnType, functionFlags);
3166331665 if (returnOrPromisedType) {
3166431666 if ((functionFlags & FunctionFlags.AsyncGenerator) === FunctionFlags.Async) { // Async function
31665- const awaitedType = checkAwaitedType(exprType, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31667+ const awaitedType = checkAwaitedType(exprType, /*withAlias*/ false, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3166631668 checkTypeAssignableToAndOptionallyElaborate(awaitedType, returnOrPromisedType, node.body, node.body);
3166731669 }
3166831670 else { // Normal function
@@ -31879,7 +31881,7 @@ namespace ts {
3187931881 }
3188031882
3188131883 const operandType = checkExpression(node.expression);
31882- const awaitedType = checkAwaitedType(operandType, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
31884+ const awaitedType = checkAwaitedType(operandType, /*withAlias*/ true, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3188331885 if (awaitedType === operandType && awaitedType !== errorType && !(operandType.flags & TypeFlags.AnyOrUnknown)) {
3188431886 addErrorOrSuggestion(/*isError*/ false, createDiagnosticForNode(node, Diagnostics.await_has_no_effect_on_the_type_of_this_expression));
3188531887 }
@@ -32831,8 +32833,8 @@ namespace ts {
3283132833 let wouldWorkWithAwait = false;
3283232834 const errNode = errorNode || operatorToken;
3283332835 if (isRelated) {
32834- const awaitedLeftType = unwrapAwaitedType(getAwaitedType( leftType) );
32835- const awaitedRightType = unwrapAwaitedType(getAwaitedType( rightType) );
32836+ const awaitedLeftType = getAwaitedTypeNoAlias( leftType);
32837+ const awaitedRightType = getAwaitedTypeNoAlias( rightType);
3283632838 wouldWorkWithAwait = !(awaitedLeftType === leftType && awaitedRightType === rightType)
3283732839 && !!(awaitedLeftType && awaitedRightType)
3283832840 && isRelated(awaitedLeftType, awaitedRightType);
@@ -34914,12 +34916,15 @@ namespace ts {
3491434916 /**
3491534917 * Gets the "awaited type" of a type.
3491634918 * @param type The type to await.
34919+ * @param withAlias When `true`, wraps the "awaited type" in `Awaited<T>` if needed.
3491734920 * @remarks The "awaited type" of an expression is its "promised type" if the expression is a
3491834921 * Promise-like type; otherwise, it is the type of the expression. This is used to reflect
3491934922 * The runtime behavior of the `await` keyword.
3492034923 */
34921- function checkAwaitedType(type: Type, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34922- const awaitedType = getAwaitedType(type, errorNode, diagnosticMessage, arg0);
34924+ function checkAwaitedType(type: Type, withAlias: boolean, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
34925+ const awaitedType = withAlias ?
34926+ getAwaitedType(type, errorNode, diagnosticMessage, arg0) :
34927+ getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
3492334928 return awaitedType || errorType;
3492434929 }
3492534930
@@ -34953,10 +34958,7 @@ namespace ts {
3495334958 /**
3495434959 * For a generic `Awaited<T>`, gets `T`.
3495534960 */
34956- function unwrapAwaitedType(type: Type): Type;
34957- function unwrapAwaitedType(type: Type | undefined): Type | undefined;
34958- function unwrapAwaitedType(type: Type | undefined) {
34959- if (!type) return undefined;
34961+ function unwrapAwaitedType(type: Type) {
3496034962 return type.flags & TypeFlags.Union ? mapType(type, unwrapAwaitedType) :
3496134963 isAwaitedTypeInstantiation(type) ? type.aliasTypeArguments[0] :
3496234964 type;
@@ -35011,6 +35013,16 @@ namespace ts {
3501135013 * This is used to reflect the runtime behavior of the `await` keyword.
3501235014 */
3501335015 function getAwaitedType(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
35016+ const awaitedType = getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
35017+ return awaitedType && createAwaitedTypeIfNeeded(awaitedType);
35018+ }
35019+
35020+ /**
35021+ * Gets the "awaited type" of a type without introducing an `Awaited<T>` wrapper.
35022+ *
35023+ * @see {@link getAwaitedType}
35024+ */
35025+ function getAwaitedTypeNoAlias(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
3501435026 if (isTypeAny(type)) {
3501535027 return type;
3501635028 }
@@ -35023,14 +35035,13 @@ namespace ts {
3502335035 // If we've already cached an awaited type, return a possible `Awaited<T>` for it.
3502435036 const typeAsAwaitable = type as PromiseOrAwaitableType;
3502535037 if (typeAsAwaitable.awaitedTypeOfType) {
35026- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType) ;
35038+ return typeAsAwaitable.awaitedTypeOfType;
3502735039 }
3502835040
3502935041 // For a union, get a union of the awaited types of each constituent.
3503035042 if (type.flags & TypeFlags.Union) {
35031- const mapper = errorNode ? (constituentType: Type) => getAwaitedType(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedType;
35032- typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
35033- return typeAsAwaitable.awaitedTypeOfType && createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
35043+ const mapper = errorNode ? (constituentType: Type) => getAwaitedTypeNoAlias(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedTypeNoAlias;
35044+ return typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
3503435045 }
3503535046
3503635047 const promisedType = getPromisedTypeOfPromise(type);
@@ -35078,14 +35089,14 @@ namespace ts {
3507835089 // Keep track of the type we're about to unwrap to avoid bad recursive promise types.
3507935090 // See the comments above for more information.
3508035091 awaitedTypeStack.push(type.id);
35081- const awaitedType = getAwaitedType (promisedType, errorNode, diagnosticMessage, arg0);
35092+ const awaitedType = getAwaitedTypeNoAlias (promisedType, errorNode, diagnosticMessage, arg0);
3508235093 awaitedTypeStack.pop();
3508335094
3508435095 if (!awaitedType) {
3508535096 return undefined;
3508635097 }
3508735098
35088- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType = awaitedType) ;
35099+ return typeAsAwaitable.awaitedTypeOfType = awaitedType;
3508935100 }
3509035101
3509135102 // The type was not a promise, so it could not be unwrapped any further.
@@ -35111,7 +35122,7 @@ namespace ts {
3511135122 return undefined;
3511235123 }
3511335124
35114- return createAwaitedTypeIfNeeded( typeAsAwaitable.awaitedTypeOfType = type) ;
35125+ return typeAsAwaitable.awaitedTypeOfType = type;
3511535126 }
3511635127
3511735128 /**
@@ -35161,7 +35172,7 @@ namespace ts {
3516135172 if (globalPromiseType !== emptyGenericType && !isReferenceToType(returnType, globalPromiseType)) {
3516235173 // The promise type was not a valid type reference to the global promise type, so we
3516335174 // report an error and return the unknown type.
35164- error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(unwrapAwaitedType(getAwaitedType( returnType) ) || voidType));
35175+ error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(getAwaitedTypeNoAlias( returnType) || voidType));
3516535176 return;
3516635177 }
3516735178 }
@@ -35214,7 +35225,7 @@ namespace ts {
3521435225 return;
3521535226 }
3521635227 }
35217- checkAwaitedType(returnType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
35228+ checkAwaitedType(returnType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
3521835229 }
3521935230
3522035231 /** Check a decorator */
@@ -37495,7 +37506,7 @@ namespace ts {
3749537506 const isGenerator = !!(functionFlags & FunctionFlags.Generator);
3749637507 const isAsync = !!(functionFlags & FunctionFlags.Async);
3749737508 return isGenerator ? getIterationTypeOfGeneratorFunctionReturnType(IterationTypeKind.Return, returnType, isAsync) || errorType :
37498- isAsync ? unwrapAwaitedType(getAwaitedType( returnType) ) || errorType :
37509+ isAsync ? getAwaitedTypeNoAlias( returnType) || errorType :
3749937510 returnType;
3750037511 }
3750137512
@@ -37539,7 +37550,7 @@ namespace ts {
3753937550 else if (getReturnTypeFromAnnotation(container)) {
3754037551 const unwrappedReturnType = unwrapReturnType(returnType, functionFlags) ?? returnType;
3754137552 const unwrappedExprType = functionFlags & FunctionFlags.Async
37542- ? checkAwaitedType(exprType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
37553+ ? checkAwaitedType(exprType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
3754337554 : exprType;
3754437555 if (unwrappedReturnType) {
3754537556 // If the function has a return type, but promisedType is
0 commit comments