Skip to content

Commit 28860fa

Browse files
authored
Fix bugs around async functions (#22)
* Fix invalid display of coroutine which return another coroutine * Fix invalid display of AsyncGenerator
1 parent 17d07db commit 28860fa

File tree

2 files changed

+9
-27
lines changed

2 files changed

+9
-27
lines changed

src/main/kotlin/space/whitememory/pythoninlayparams/types/hints/HintGenerator.kt

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,6 @@ enum class HintGenerator {
4848
}
4949
},
5050

51-
ASYNC_TYPE {
52-
override fun handleType(
53-
element: PyElement,
54-
type: PyType?,
55-
typeEvalContext: TypeEvalContext
56-
): List<InlayInfoDetails>? {
57-
if (type == null || element !is PyFunction) return null
58-
59-
if (type is PyCollectionType && type.classQName == PyTypingTypeProvider.COROUTINE && element.isAsync) {
60-
return generateTypeHintText(
61-
element, PyTypingTypeProvider.coroutineOrGeneratorElementType(type)?.get(), typeEvalContext
62-
)
63-
}
64-
65-
return null
66-
}
67-
},
68-
6951
COLLECTION_TYPE {
7052
override fun handleType(
7153
element: PyElement,

src/main/kotlin/space/whitememory/pythoninlayparams/types/hints/HintResolver.kt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ enum class HintResolver {
6363

6464
// Handle case `var = async_func()` without `await` keyword
6565
if (assignedValue is PyCallExpression) return true
66+
// Handle case 'var = await async_func()` which return `Coroutine` inside
67+
if (assignedValue is PyPrefixExpression && assignedValue.operator == PyTokenTypes.AWAIT_KEYWORD) return true
6668

6769
if (typeAnnotation is PyClassType && isElementInsideTypingModule(typeAnnotation.pyClass)) return false
6870

@@ -398,7 +400,13 @@ enum class HintResolver {
398400
}
399401

400402
fun getExpressionAnnotationType(element: PyElement, typeEvalContext: TypeEvalContext): PyType? {
401-
if (element is PyFunction) return typeEvalContext.getReturnType(element)
403+
if (element is PyFunction) {
404+
if (element.isAsync && !element.isGenerator) {
405+
return element.getReturnStatementType(typeEvalContext)
406+
}
407+
408+
return typeEvalContext.getReturnType(element)
409+
}
402410
if (element is PyTargetExpression) return typeEvalContext.getType(element)
403411

404412
return null
@@ -428,14 +436,6 @@ enum class HintResolver {
428436
}
429437
}
430438

431-
if (element is PyFunction && element.isAsync) {
432-
val functionType = PyTypingTypeProvider.coroutineOrGeneratorElementType(typeAnnotation)?.get()
433-
434-
if (functionType is PyNoneType || PyTypeChecker.isUnknown(functionType, false, typeEvalContext)) {
435-
return false
436-
}
437-
}
438-
439439
if (PyTypeChecker.isUnknown(typeAnnotation, false, typeEvalContext)) return false
440440

441441
return true

0 commit comments

Comments
 (0)