11package org.basedsoft.plugins.basedtyping
22
3+ import com.intellij.openapi.util.Key
34import com.intellij.openapi.util.Ref
45import com.intellij.psi.PsiElement
6+ import com.intellij.psi.impl.source.resolve.FileContextUtil
57import com.intellij.psi.util.siblings
68import com.jetbrains.python.PyTokenTypes
79import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
@@ -67,6 +69,7 @@ fun getType(expression: PyExpression, context: TypeEvalContext, simple: Boolean
6769 return getLiteralType(expression, context)
6870 ? : getUnionType(expression, context)
6971 ? : getTupleType(expression, context)
72+ ? : getStringType(expression, context)
7073 ? : if (simple) null else PyTypingTypeProvider .getType(expression, context)?.get()
7174}
7275
@@ -100,6 +103,25 @@ fun getTupleType(target: PyExpression, context: TypeEvalContext): PyType? {
100103 }.children.map { getType(it as PyExpression , context) }.let { PyTupleType .create(target, it) }
101104}
102105
106+ fun getStringType (target : PyExpression , context : TypeEvalContext ): PyType ? {
107+ if (target !is PyStringLiteralExpression ) return null
108+ val value = target.stringValue.trim()
109+ if (! ((value.startsWith(" def (" ) || value.startsWith(" (" )) && " ->" in value)) return null
110+ val (l, r) = value.split(" ->" )
111+ val argsExpressionPart1 = toExpression(l.removePrefix(" def" ).trim(), target)
112+ val argsExpression = when (argsExpressionPart1) {
113+ is PyTupleExpression -> argsExpressionPart1
114+ is PyParenthesizedExpression -> argsExpressionPart1.containedExpression!!
115+ else -> return null
116+ }
117+ val args = when (argsExpression) {
118+ is PyTupleExpression -> argsExpression.map { getType(it, context) }
119+ else -> listOf (getType(argsExpression, context))
120+ }
121+ val returnType = getType(toExpression(r.trim(), target) ? : return null , context)
122+ return PyCallableTypeImpl (args.map { PyCallableParameterImpl .nonPsi(it) }, returnType)
123+ }
124+
103125fun PyFunction.collectOverloads () =
104126 siblings(forward = false )
105127 .filterIsInstance(PyFunction ::class .java)
@@ -129,4 +151,16 @@ fun getOverloadReturn(func: PyFunction, context: TypeEvalContext) =
129151 .takeIf { it.isNotEmpty() }
130152 ?.let { PyUnionType .union(it) }
131153
154+ val FRAGMENT_OWNER : Key <PsiElement > = Key .create(" PY_FRAGMENT_OWNER" )
155+
156+ fun toExpression (contents : String , anchor : PsiElement ): PyExpression ? {
157+ val file = FileContextUtil .getContextFile(anchor)
158+ if (file == null ) return null
159+ val fragment = PyUtil .createExpressionFromFragment(contents, file)
160+ if (fragment != null ) {
161+ fragment.getContainingFile().putUserData(FRAGMENT_OWNER , anchor)
162+ }
163+ return fragment
164+ }
165+
132166fun PyType?.ref () = this ?.let { Ref .create(it) }
0 commit comments