Skip to content

Commit 6b628f8

Browse files
committed
[Compiler plugin] Refactor type refinement pipeline to work for multiple types without much code duplication
1 parent 98c0ae8 commit 6b628f8

File tree

4 files changed

+61
-128
lines changed

4 files changed

+61
-128
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt

Lines changed: 14 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,33 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType
1212
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
1313
import org.jetbrains.kotlin.fir.types.classId
1414
import org.jetbrains.kotlin.fir.types.resolvedType
15+
import org.jetbrains.kotlin.name.ClassId
1516
import org.jetbrains.kotlin.name.Name
1617
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
1718
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
18-
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
19-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
20-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
21-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
22-
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
23-
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.GROUP_BY_CLASS_ID
2419

25-
fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
20+
internal inline fun <reified T> KotlinTypeFacade.analyzeRefinedCallShape(
21+
call: FirFunctionCall,
22+
expectedReturnType: ClassId,
23+
reporter: InterpretationErrorReporter
24+
): CallResult<T>? {
2625
val callReturnType = call.resolvedType
27-
if (callReturnType.classId != DF_CLASS_ID) return null
28-
val rootMarker = callReturnType.typeArguments[0]
26+
if (callReturnType.classId != expectedReturnType) return null
2927
// rootMarker is expected to be a token generated by the plugin.
3028
// it's implied by "refined call"
3129
// thus ConeClassLikeType
32-
if (rootMarker !is ConeClassLikeType) {
33-
return null
34-
}
30+
val rootMarkers = callReturnType.typeArguments.filterIsInstance<ConeClassLikeType>()
31+
if (rootMarkers.size != callReturnType.typeArguments.size) return null
3532

36-
val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
33+
val newSchema: T = call.interpreterName(session)?.let { name ->
3734
when (name) {
3835
else -> name.load<Interpreter<*>>().let { processor ->
3936
val dataFrameSchema = interpret(call, processor, reporter = reporter)
4037
.let {
4138
val value = it?.value
42-
if (value !is PluginDataFrameSchema) {
39+
if (value !is T) {
4340
if (!reporter.errorReported) {
44-
reporter.reportInterpretationError(call, "${processor::class} must return ${PluginDataFrameSchema::class}, but was ${value}")
41+
reporter.reportInterpretationError(call, "${processor::class} must return ${T::class}, but was $value")
4542
}
4643
return null
4744
}
@@ -52,57 +49,10 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
5249
}
5350
} ?: return null
5451

55-
return CallResult(rootMarker, newSchema)
52+
return CallResult(rootMarkers, newSchema)
5653
}
5754

58-
fun KotlinTypeFacade.analyzeRefinedGroupByCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): GroupByCallResult? {
59-
val callReturnType = call.resolvedType
60-
if (callReturnType.classId != GROUP_BY_CLASS_ID) return null
61-
val keyMarker = callReturnType.typeArguments[0]
62-
val groupMarker = callReturnType.typeArguments[1]
63-
// rootMarker is expected to be a token generated by the plugin.
64-
// it's implied by "refined call"
65-
// thus ConeClassLikeType
66-
if (keyMarker !is ConeClassLikeType || groupMarker !is ConeClassLikeType) {
67-
return null
68-
}
69-
70-
val newSchema = call.interpreterName(session)?.let { name ->
71-
name.load<Interpreter<*>>().let { processor ->
72-
val dataFrameSchema = interpret(call, processor, reporter = reporter)
73-
.let {
74-
val value = it?.value
75-
if (value !is GroupBy) {
76-
if (!reporter.errorReported) {
77-
reporter.reportInterpretationError(call, "${processor::class} must return ${PluginDataFrameSchema::class}, but was ${value}")
78-
}
79-
return null
80-
}
81-
value
82-
}
83-
84-
val keySchema = createPluginDataFrameSchema(dataFrameSchema.keys, dataFrameSchema.moveToTop)
85-
val groupSchema = PluginDataFrameSchema(dataFrameSchema.df.columns())
86-
GroupBySchema(keySchema, groupSchema)
87-
}
88-
} ?: return null
89-
90-
return GroupByCallResult(keyMarker, newSchema.keySchema, groupMarker, newSchema.groupSchema)
91-
}
92-
93-
data class GroupByCallResult(
94-
val keyMarker: ConeClassLikeType,
95-
val keySchema: PluginDataFrameSchema,
96-
val groupMarker: ConeClassLikeType,
97-
val groupSchema: PluginDataFrameSchema,
98-
)
99-
100-
data class GroupBySchema(
101-
val keySchema: PluginDataFrameSchema,
102-
val groupSchema: PluginDataFrameSchema,
103-
)
104-
105-
data class CallResult(val rootMarker: ConeClassLikeType, val newSchema: PluginDataFrameSchema)
55+
data class CallResult<T>(val markers: List<ConeClassLikeType>, val result: T)
10656

10757
class RefinedArguments(val refinedArguments: List<RefinedArgument>) : List<RefinedArgument> by refinedArguments
10858

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,13 @@ import org.jetbrains.kotlin.name.FqName
7474
import org.jetbrains.kotlin.name.Name
7575
import org.jetbrains.kotlin.text
7676
import org.jetbrains.kotlin.types.Variance
77-
import org.jetbrains.kotlinx.dataframe.plugin.analyzeRefinedGroupByCallShape
7877
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
7978
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
8079
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
8180
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
8281
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
82+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
83+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
8384
import kotlin.math.abs
8485

8586
@OptIn(FirExtensionApiInternals::class)
@@ -222,14 +223,10 @@ class FunctionCallTransformer(
222223

223224
@OptIn(SymbolInternals::class)
224225
override fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? {
225-
val analyzeRefinedCallShape = analyzeRefinedCallShape(call, InterpretationErrorReporter.DEFAULT)
226-
227-
val (token, dataFrameSchema) =
228-
analyzeRefinedCallShape ?: return null
229-
230-
226+
val callResult = analyzeRefinedCallShape<PluginDataFrameSchema>(call, Names.DF_CLASS_ID, InterpretationErrorReporter.DEFAULT)
227+
val (tokens, dataFrameSchema) = callResult ?: return null
228+
val token = tokens[0]
231229
val firstSchema = token.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
232-
233230
val dataSchemaApis = materialize(dataFrameSchema, call, firstSchema)
234231

235232
val tokenFir = token.toClassSymbol(session)!!.fir
@@ -272,23 +269,28 @@ class FunctionCallTransformer(
272269

273270
@OptIn(SymbolInternals::class)
274271
override fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? {
275-
val (keyMarker, keySchema, groupMarker, groupSchema) = analyzeRefinedGroupByCallShape(
276-
call,
277-
InterpretationErrorReporter.DEFAULT
278-
) ?: return null
272+
val callResult = analyzeRefinedCallShape<GroupBy>(call, Names.GROUP_BY_CLASS_ID, InterpretationErrorReporter.DEFAULT)
273+
val (rootMarkers, groupBy) = callResult ?: return null
274+
275+
val keyMarker = rootMarkers[0]
276+
val groupMarker = rootMarkers[1]
277+
278+
val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop)
279+
val groupSchema = PluginDataFrameSchema(groupBy.df.columns())
280+
279281
val firstSchema = keyMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
280282
val firstSchema1 = groupMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
281283

282-
val dataSchemaApis = materialize(keySchema, call, firstSchema, "Key")
283-
val dataSchemaApis1 = materialize(groupSchema, call, firstSchema1, "Group", dataSchemaApis.size)
284+
val keyApis = materialize(keySchema, call, firstSchema, "Key")
285+
val groupApis = materialize(groupSchema, call, firstSchema1, "Group", i = keyApis.size)
284286

285-
val tokenFir = keyMarker.toClassSymbol(session)!!.fir
286-
tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol })
287+
val groupToken = keyMarker.toClassSymbol(session)!!.fir
288+
groupToken.callShapeData = CallShapeData.RefinedType(keyApis.map { it.scope.symbol })
287289

288-
val tokenFir1 = groupMarker.toClassSymbol(session)!!.fir
289-
tokenFir1.callShapeData = CallShapeData.RefinedType(dataSchemaApis1.map { it.scope.symbol })
290+
val keyToken = groupMarker.toClassSymbol(session)!!.fir
291+
keyToken.callShapeData = CallShapeData.RefinedType(groupApis.map { it.scope.symbol })
290292

291-
return buildLetCall(call, originalSymbol, dataSchemaApis + dataSchemaApis1, listOf(tokenFir, tokenFir1))
293+
return buildLetCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
292294
}
293295
}
294296

Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package org.jetbrains.kotlinx.dataframe.plugin.extensions
22

33
import org.jetbrains.kotlin.fir.FirSession
4-
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
54
import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
65
import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
76
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
@@ -10,49 +9,33 @@ import org.jetbrains.kotlin.fir.scopes.collectAllProperties
109
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
1110
import org.jetbrains.kotlin.fir.symbols.SymbolInternals
1211
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
13-
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
14-
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
1512
import org.jetbrains.kotlin.fir.types.ConeKotlinType
1613
import org.jetbrains.kotlin.fir.types.classId
1714
import org.jetbrains.kotlin.fir.types.resolvedType
1815
import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
16+
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
1917

2018
class ReturnTypeBasedReceiverInjector(session: FirSession) : FirExpressionResolutionExtension(session) {
19+
@OptIn(SymbolInternals::class)
2120
override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List<ConeKotlinType> {
2221
val callReturnType = functionCall.resolvedType
23-
if (callReturnType.classId == Names.GROUP_BY_CLASS_ID) {
24-
val rootMarker = callReturnType.typeArguments[0]
25-
val rootMarker1 = callReturnType.typeArguments[1]
26-
if (rootMarker !is ConeClassLikeType || rootMarker1 !is ConeClassLikeType) {
27-
return emptyList()
28-
}
29-
val symbol = rootMarker.toRegularClassSymbol(session)
30-
val symbol1 = rootMarker1.toRegularClassSymbol(session)
31-
32-
return listOfNotNull(symbol, symbol1).flatMap {
33-
it.declaredMemberScope(session, FirResolvePhase.DECLARATIONS).collectAllProperties()
34-
.filterIsInstance<FirPropertySymbol>()
35-
.filter { it.getAnnotationByClassId(Names.SCOPE_PROPERTY_ANNOTATION, session) != null }
36-
.map { it.resolvedReturnType }
37-
}
22+
return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID)) {
23+
val typeArguments = callReturnType.typeArguments
24+
typeArguments
25+
.mapNotNull {
26+
val symbol = (it as? ConeKotlinType)?.toRegularClassSymbol(session)
27+
symbol?.takeIf { it.fir.callShapeData != null }
28+
}
29+
.takeIf { it.size == typeArguments.size }
30+
.orEmpty()
31+
.flatMap { marker ->
32+
marker.declaredMemberScope(session, FirResolvePhase.DECLARATIONS).collectAllProperties()
33+
.filterIsInstance<FirPropertySymbol>()
34+
.filter { it.getAnnotationByClassId(Names.SCOPE_PROPERTY_ANNOTATION, session) != null }
35+
.map { it.resolvedReturnType }
36+
}
37+
} else {
38+
emptyList()
3839
}
39-
val symbol = generatedTokenOrNull(functionCall) ?: return emptyList()
40-
return symbol.declaredMemberScope(session, FirResolvePhase.DECLARATIONS).collectAllProperties()
41-
.filterIsInstance<FirPropertySymbol>()
42-
.filter { it.getAnnotationByClassId(Names.SCOPE_PROPERTY_ANNOTATION, session) != null }
43-
.map { it.resolvedReturnType }
44-
}
45-
46-
@OptIn(SymbolInternals::class)
47-
private fun generatedTokenOrNull(call: FirFunctionCall): FirRegularClassSymbol? {
48-
val callReturnType = call.resolvedType
49-
if (callReturnType.classId != Names.DF_CLASS_ID) return null
50-
val rootMarker = callReturnType.typeArguments[0]
51-
if (rootMarker !is ConeClassLikeType) {
52-
return null
53-
}
54-
55-
val symbol = rootMarker.toRegularClassSymbol(session)
56-
return symbol.takeIf { it?.fir?.callShapeData != null }
5740
}
5841
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,19 +145,17 @@ fun <T> KotlinTypeFacade.interpret(
145145
}
146146

147147
is FirFunctionCall -> {
148-
var interpreter = expression.loadInterpreter()
148+
val interpreter = expression.loadInterpreter()
149149
if (interpreter == null) {
150-
val r = expression.arguments[0]
151-
val last = (r as? FirAnonymousFunctionExpression)?.anonymousFunction?.body?.statements?.lastOrNull()
150+
// if the plugin already transformed call, its original form is the last expression of .let { }
151+
val argument = expression.arguments[0]
152+
val last = (argument as? FirAnonymousFunctionExpression)?.anonymousFunction?.body?.statements?.lastOrNull()
152153
val call = (last as? FirReturnExpression)?.result as? FirFunctionCall
153-
val interpreter = call?.loadInterpreter()
154-
if (interpreter != null) {
155-
interpret(call, interpreter, emptyMap(), reporter)
156-
} else {
157-
null
154+
call?.loadInterpreter()?.let {
155+
interpret(call, it, emptyMap(), reporter)
158156
}
159157
} else {
160-
interpreter?.let {
158+
interpreter.let {
161159
val result = interpret(expression, interpreter, emptyMap(), reporter)
162160
result
163161
}

0 commit comments

Comments
 (0)