@@ -63,6 +63,8 @@ enum class HintResolver {
6363
6464 // Handle case `var = async_func()` without `await` keyword
6565 if (assignedValue is PyCallExpression ) return true
66+ // Handle case 'var = await async_func()` which return `Coroutine` inside
67+ if (assignedValue is PyPrefixExpression && assignedValue.operator == PyTokenTypes .AWAIT_KEYWORD ) return true
6668
6769 if (typeAnnotation is PyClassType && isElementInsideTypingModule(typeAnnotation.pyClass)) return false
6870
@@ -398,7 +400,13 @@ enum class HintResolver {
398400 }
399401
400402 fun getExpressionAnnotationType (element : PyElement , typeEvalContext : TypeEvalContext ): PyType ? {
401- if (element is PyFunction ) return typeEvalContext.getReturnType(element)
403+ if (element is PyFunction ) {
404+ if (element.isAsync && ! element.isGenerator) {
405+ return element.getReturnStatementType(typeEvalContext)
406+ }
407+
408+ return typeEvalContext.getReturnType(element)
409+ }
402410 if (element is PyTargetExpression ) return typeEvalContext.getType(element)
403411
404412 return null
@@ -428,14 +436,6 @@ enum class HintResolver {
428436 }
429437 }
430438
431- if (element is PyFunction && element.isAsync) {
432- val functionType = PyTypingTypeProvider .coroutineOrGeneratorElementType(typeAnnotation)?.get()
433-
434- if (functionType is PyNoneType || PyTypeChecker .isUnknown(functionType, false , typeEvalContext)) {
435- return false
436- }
437- }
438-
439439 if (PyTypeChecker .isUnknown(typeAnnotation, false , typeEvalContext)) return false
440440
441441 return true
0 commit comments