Skip to content

Commit 98c0ae8

Browse files
committed
[Compiler plugin] Remove code duplication caused by aggregate implementation
1 parent 233cda0 commit 98c0ae8

File tree

5 files changed

+42
-30
lines changed

5 files changed

+42
-30
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,6 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
3535

3636
val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
3737
when (name) {
38-
"Aggregate" -> {
39-
val groupByCall = call.explicitReceiver as? FirFunctionCall
40-
val interpreter = groupByCall?.loadInterpreter(session)
41-
if (interpreter != null) {
42-
aggregate(groupByCall, interpreter, reporter, call)
43-
} else {
44-
PluginDataFrameSchema(emptyList())
45-
}
46-
}
4738
else -> name.load<Interpreter<*>>().let { processor ->
4839
val dataFrameSchema = interpret(call, processor, reporter = reporter)
4940
.let {

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ class FunctionCallTransformer(
279279
val firstSchema = keyMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
280280
val firstSchema1 = groupMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
281281

282-
val dataSchemaApis = materialize(keySchema, call, firstSchema)
283-
val dataSchemaApis1 = materialize(groupSchema, call, firstSchema1, dataSchemaApis.size)
282+
val dataSchemaApis = materialize(keySchema, call, firstSchema, "Key")
283+
val dataSchemaApis1 = materialize(groupSchema, call, firstSchema1, "Group", dataSchemaApis.size)
284284

285285
val tokenFir = keyMarker.toClassSymbol(session)!!.fir
286286
tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol })
@@ -431,6 +431,7 @@ class FunctionCallTransformer(
431431
dataFrameSchema: PluginDataFrameSchema,
432432
call: FirFunctionCall,
433433
firstSchema: FirRegularClass,
434+
prefix: String = "",
434435
i: Int = 0
435436
): List<DataSchemaApi> {
436437
var i = i
@@ -469,7 +470,7 @@ class FunctionCallTransformer(
469470
val text = call.source?.text ?: call.calleeReference.name
470471
val name =
471472
"${column.name.titleCase().replEscapeLineBreaks()}_${hashToTwoCharString(abs(text.hashCode()))}"
472-
return materialize(suggestedName = name)
473+
return materialize(suggestedName = "$prefix$name")
473474
}
474475

475476
when (it) {

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,11 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api
33
import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
44
import org.jetbrains.kotlinx.dataframe.plugin.interpret
55
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
6-
import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema
7-
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
86
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
97
import org.jetbrains.kotlin.fir.expressions.FirExpression
108
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
119
import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
1210
import org.jetbrains.kotlin.fir.types.ConeKotlinType
13-
import org.jetbrains.kotlin.fir.types.ConeNullability
14-
import org.jetbrains.kotlin.fir.types.classId
1511
import org.jetbrains.kotlin.fir.types.resolvedType
1612
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
1713
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
@@ -55,21 +51,30 @@ class GroupByInto : AbstractInterpreter<Unit>() {
5551
}
5652
}
5753

54+
class Aggregate : AbstractSchemaModificationInterpreter() {
55+
val Arguments.receiver: GroupBy by arg()
56+
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
57+
override fun Arguments.interpret(): PluginDataFrameSchema {
58+
return aggregate(
59+
receiver,
60+
InterpretationErrorReporter.DEFAULT,
61+
body
62+
)
63+
}
64+
}
65+
5866
fun KotlinTypeFacade.aggregate(
59-
groupByCall: FirFunctionCall,
60-
interpreter: Interpreter<*>,
67+
groupBy: GroupBy,
6168
reporter: InterpretationErrorReporter,
62-
call: FirFunctionCall
63-
): PluginDataFrameSchema? {
64-
val groupBy = interpret(groupByCall, interpreter, reporter = reporter)?.value as? GroupBy ?: return null
65-
val aggregate = call.argumentList.arguments.singleOrNull() as? FirAnonymousFunctionExpression
66-
val body = aggregate?.anonymousFunction?.body ?: return null
67-
val lastExpression = (body.statements.lastOrNull() as? FirReturnExpression)?.result
69+
firAnonymousFunctionExpression: FirAnonymousFunctionExpression
70+
): PluginDataFrameSchema {
71+
val body = firAnonymousFunctionExpression.anonymousFunction.body
72+
val lastExpression = (body?.statements?.lastOrNull() as? FirReturnExpression)?.result
6873
val type = lastExpression?.resolvedType
6974
return if (type != session.builtinTypes.unitType) {
7075
val dsl = GroupByDsl()
7176
val calls = buildList {
72-
body.statements.filterIsInstance<FirFunctionCall>().let { addAll(it) }
77+
body?.statements?.filterIsInstance<FirFunctionCall>()?.let { addAll(it) }
7378
if (lastExpression is FirFunctionCall) add(lastExpression)
7479
}
7580
calls.forEach { call ->
@@ -87,7 +92,7 @@ fun KotlinTypeFacade.aggregate(
8792
}
8893
PluginDataFrameSchema(cols)
8994
} else {
90-
null
95+
PluginDataFrameSchema(emptyList())
9196
}
9297
}
9398

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
2828
import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
2929
import org.jetbrains.kotlin.fir.expressions.FirThisReceiverExpression
3030
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
31+
import org.jetbrains.kotlin.fir.expressions.arguments
3132
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
3233
import org.jetbrains.kotlin.fir.references.FirResolvedCallableReference
3334
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
@@ -144,10 +145,22 @@ fun <T> KotlinTypeFacade.interpret(
144145
}
145146

146147
is FirFunctionCall -> {
147-
val interpreter = expression.loadInterpreter()
148-
interpreter?.let {
149-
val result = interpret(expression, interpreter, emptyMap(), reporter)
150-
result
148+
var interpreter = expression.loadInterpreter()
149+
if (interpreter == null) {
150+
val r = expression.arguments[0]
151+
val last = (r as? FirAnonymousFunctionExpression)?.anonymousFunction?.body?.statements?.lastOrNull()
152+
val call = (last as? FirReturnExpression)?.result as? FirFunctionCall
153+
val interpreter = call?.loadInterpreter()
154+
if (interpreter != null) {
155+
interpret(call, interpreter, emptyMap(), reporter)
156+
} else {
157+
null
158+
}
159+
} else {
160+
interpreter?.let {
161+
val result = interpret(expression, interpreter, emptyMap(), reporter)
162+
result
163+
}
151164
}
152165
}
153166

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ import org.jetbrains.kotlin.name.ClassId
6969
import org.jetbrains.kotlin.name.Name
7070
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
7171
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId
72+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate
7273
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All0
7374
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth0
7475
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
@@ -206,6 +207,7 @@ internal inline fun <reified T> String.load(): T {
206207
"Move0" -> Move0()
207208
"ToTop" -> ToTop()
208209
"Update0" -> Update0()
210+
"Aggregate" -> Aggregate()
209211
else -> error("$this")
210212
} as T
211213
}

0 commit comments

Comments
 (0)