Skip to content

Commit 233cda0

Browse files
committed
[Compiler plugin] Support transformation of functions with GroupBy type
1 parent 83296fd commit 233cda0

File tree

7 files changed

+313
-114
lines changed

7 files changed

+313
-114
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import kotlin.reflect.KProperty
2929
*
3030
* `df.add("columnName") { "someColumn"<Int>() + 15 }.groupBy("columnName")`
3131
*/
32+
@Refine
3233
@Interpretable("DataFrameGroupBy")
3334
public fun <T> DataFrame<T>.groupBy(moveToTop: Boolean = true, cols: ColumnsSelector<T, *>): GroupBy<T, T> =
3435
groupByImpl(moveToTop, cols)

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ import org.jetbrains.kotlin.name.Name
1616
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
1717
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
1818
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
19+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
1920
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
21+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
2022
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
23+
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.GROUP_BY_CLASS_ID
2124

2225
fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
2326
val callReturnType = call.resolvedType
@@ -61,6 +64,53 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
6164
return CallResult(rootMarker, newSchema)
6265
}
6366

67+
fun KotlinTypeFacade.analyzeRefinedGroupByCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): GroupByCallResult? {
68+
val callReturnType = call.resolvedType
69+
if (callReturnType.classId != GROUP_BY_CLASS_ID) return null
70+
val keyMarker = callReturnType.typeArguments[0]
71+
val groupMarker = callReturnType.typeArguments[1]
72+
// rootMarker is expected to be a token generated by the plugin.
73+
// it's implied by "refined call"
74+
// thus ConeClassLikeType
75+
if (keyMarker !is ConeClassLikeType || groupMarker !is ConeClassLikeType) {
76+
return null
77+
}
78+
79+
val newSchema = call.interpreterName(session)?.let { name ->
80+
name.load<Interpreter<*>>().let { processor ->
81+
val dataFrameSchema = interpret(call, processor, reporter = reporter)
82+
.let {
83+
val value = it?.value
84+
if (value !is GroupBy) {
85+
if (!reporter.errorReported) {
86+
reporter.reportInterpretationError(call, "${processor::class} must return ${PluginDataFrameSchema::class}, but was ${value}")
87+
}
88+
return null
89+
}
90+
value
91+
}
92+
93+
val keySchema = createPluginDataFrameSchema(dataFrameSchema.keys, dataFrameSchema.moveToTop)
94+
val groupSchema = PluginDataFrameSchema(dataFrameSchema.df.columns())
95+
GroupBySchema(keySchema, groupSchema)
96+
}
97+
} ?: return null
98+
99+
return GroupByCallResult(keyMarker, newSchema.keySchema, groupMarker, newSchema.groupSchema)
100+
}
101+
102+
data class GroupByCallResult(
103+
val keyMarker: ConeClassLikeType,
104+
val keySchema: PluginDataFrameSchema,
105+
val groupMarker: ConeClassLikeType,
106+
val groupSchema: PluginDataFrameSchema,
107+
)
108+
109+
data class GroupBySchema(
110+
val keySchema: PluginDataFrameSchema,
111+
val groupSchema: PluginDataFrameSchema,
112+
)
113+
64114
data class CallResult(val rootMarker: ConeClassLikeType, val newSchema: PluginDataFrameSchema)
65115

66116
class RefinedArguments(val refinedArguments: List<RefinedArgument>) : List<RefinedArgument> by refinedArguments

0 commit comments

Comments
 (0)