Skip to content

Commit 934de94

Browse files
committed
intellij: callable literal types
1 parent 24ac38b commit 934de94

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

intellij/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
## [Unreleased]
66

7+
### Added
8+
- Callable literal types `() -> int`
9+
710
## [0.1.2] - 2023-01-12
811

912
### Fixed

intellij/src/main/kotlin/org/basedsoft/plugins/basedtyping/BasedTypingTypeProvider.kt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package org.basedsoft.plugins.basedtyping
22

3+
import com.intellij.openapi.util.Key
34
import com.intellij.openapi.util.Ref
45
import com.intellij.psi.PsiElement
6+
import com.intellij.psi.impl.source.resolve.FileContextUtil
57
import com.intellij.psi.util.siblings
68
import com.jetbrains.python.PyTokenTypes
79
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
@@ -67,6 +69,7 @@ fun getType(expression: PyExpression, context: TypeEvalContext, simple: Boolean
6769
return getLiteralType(expression, context)
6870
?: getUnionType(expression, context)
6971
?: getTupleType(expression, context)
72+
?: getStringType(expression, context)
7073
?: if (simple) null else PyTypingTypeProvider.getType(expression, context)?.get()
7174
}
7275

@@ -100,6 +103,25 @@ fun getTupleType(target: PyExpression, context: TypeEvalContext): PyType? {
100103
}.children.map { getType(it as PyExpression, context) }.let { PyTupleType.create(target, it) }
101104
}
102105

106+
fun getStringType(target: PyExpression, context: TypeEvalContext): PyType? {
107+
if (target !is PyStringLiteralExpression) return null
108+
val value = target.stringValue.trim()
109+
if (!((value.startsWith("def (") || value.startsWith("(")) && "->" in value)) return null
110+
val (l, r) = value.split("->")
111+
val argsExpressionPart1 = toExpression(l.removePrefix("def").trim(), target)
112+
val argsExpression = when (argsExpressionPart1) {
113+
is PyTupleExpression -> argsExpressionPart1
114+
is PyParenthesizedExpression -> argsExpressionPart1.containedExpression!!
115+
else -> return null
116+
}
117+
val args = when (argsExpression) {
118+
is PyTupleExpression -> argsExpression.map { getType(it, context) }
119+
else -> listOf(getType(argsExpression, context))
120+
}
121+
val returnType = getType(toExpression(r.trim(), target) ?: return null, context)
122+
return PyCallableTypeImpl(args.map { PyCallableParameterImpl.nonPsi(it) }, returnType)
123+
}
124+
103125
fun PyFunction.collectOverloads() =
104126
siblings(forward = false)
105127
.filterIsInstance(PyFunction::class.java)
@@ -129,4 +151,16 @@ fun getOverloadReturn(func: PyFunction, context: TypeEvalContext) =
129151
.takeIf { it.isNotEmpty() }
130152
?.let { PyUnionType.union(it) }
131153

154+
val FRAGMENT_OWNER: Key<PsiElement> = Key.create("PY_FRAGMENT_OWNER")
155+
156+
fun toExpression(contents: String, anchor: PsiElement): PyExpression? {
157+
val file = FileContextUtil.getContextFile(anchor)
158+
if (file == null) return null
159+
val fragment = PyUtil.createExpressionFromFragment(contents, file)
160+
if (fragment != null) {
161+
fragment.getContainingFile().putUserData(FRAGMENT_OWNER, anchor)
162+
}
163+
return fragment
164+
}
165+
132166
fun PyType?.ref() = this?.let { Ref.create(it) }

0 commit comments

Comments
 (0)