@@ -16,8 +16,11 @@ import org.jetbrains.kotlin.name.Name
16
16
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
17
17
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
18
18
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
19
+ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
19
20
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
21
+ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
20
22
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
23
+ import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.GROUP_BY_CLASS_ID
21
24
22
25
fun KotlinTypeFacade.analyzeRefinedCallShape (call : FirFunctionCall , reporter : InterpretationErrorReporter ): CallResult ? {
23
26
val callReturnType = call.resolvedType
@@ -61,6 +64,53 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
61
64
return CallResult (rootMarker, newSchema)
62
65
}
63
66
67
+ fun KotlinTypeFacade.analyzeRefinedGroupByCallShape (call : FirFunctionCall , reporter : InterpretationErrorReporter ): GroupByCallResult ? {
68
+ val callReturnType = call.resolvedType
69
+ if (callReturnType.classId != GROUP_BY_CLASS_ID ) return null
70
+ val keyMarker = callReturnType.typeArguments[0 ]
71
+ val groupMarker = callReturnType.typeArguments[1 ]
72
+ // rootMarker is expected to be a token generated by the plugin.
73
+ // it's implied by "refined call"
74
+ // thus ConeClassLikeType
75
+ if (keyMarker !is ConeClassLikeType || groupMarker !is ConeClassLikeType ) {
76
+ return null
77
+ }
78
+
79
+ val newSchema = call.interpreterName(session)?.let { name ->
80
+ name.load<Interpreter <* >>().let { processor ->
81
+ val dataFrameSchema = interpret(call, processor, reporter = reporter)
82
+ .let {
83
+ val value = it?.value
84
+ if (value !is GroupBy ) {
85
+ if (! reporter.errorReported) {
86
+ reporter.reportInterpretationError(call, " ${processor::class } must return ${PluginDataFrameSchema ::class } , but was ${value} " )
87
+ }
88
+ return null
89
+ }
90
+ value
91
+ }
92
+
93
+ val keySchema = createPluginDataFrameSchema(dataFrameSchema.keys, dataFrameSchema.moveToTop)
94
+ val groupSchema = PluginDataFrameSchema (dataFrameSchema.df.columns())
95
+ GroupBySchema (keySchema, groupSchema)
96
+ }
97
+ } ? : return null
98
+
99
+ return GroupByCallResult (keyMarker, newSchema.keySchema, groupMarker, newSchema.groupSchema)
100
+ }
101
+
102
+ data class GroupByCallResult (
103
+ val keyMarker : ConeClassLikeType ,
104
+ val keySchema : PluginDataFrameSchema ,
105
+ val groupMarker : ConeClassLikeType ,
106
+ val groupSchema : PluginDataFrameSchema ,
107
+ )
108
+
109
+ data class GroupBySchema (
110
+ val keySchema : PluginDataFrameSchema ,
111
+ val groupSchema : PluginDataFrameSchema ,
112
+ )
113
+
64
114
data class CallResult (val rootMarker : ConeClassLikeType , val newSchema : PluginDataFrameSchema )
65
115
66
116
class RefinedArguments (val refinedArguments : List <RefinedArgument >) : List<RefinedArgument> by refinedArguments
0 commit comments