@@ -12,36 +12,33 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType
12
12
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
13
13
import org.jetbrains.kotlin.fir.types.classId
14
14
import org.jetbrains.kotlin.fir.types.resolvedType
15
+ import org.jetbrains.kotlin.name.ClassId
15
16
import org.jetbrains.kotlin.name.Name
16
17
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
17
18
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
18
- import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
19
- import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
20
- import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
21
- import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
22
- import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
23
- import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.GROUP_BY_CLASS_ID
24
19
25
- fun KotlinTypeFacade.analyzeRefinedCallShape (call : FirFunctionCall , reporter : InterpretationErrorReporter ): CallResult ? {
20
+ internal inline fun <reified T > KotlinTypeFacade.analyzeRefinedCallShape (
21
+ call : FirFunctionCall ,
22
+ expectedReturnType : ClassId ,
23
+ reporter : InterpretationErrorReporter
24
+ ): CallResult <T >? {
26
25
val callReturnType = call.resolvedType
27
- if (callReturnType.classId != DF_CLASS_ID ) return null
28
- val rootMarker = callReturnType.typeArguments[0 ]
26
+ if (callReturnType.classId != expectedReturnType) return null
29
27
// rootMarker is expected to be a token generated by the plugin.
30
28
// it's implied by "refined call"
31
29
// thus ConeClassLikeType
32
- if (rootMarker !is ConeClassLikeType ) {
33
- return null
34
- }
30
+ val rootMarkers = callReturnType.typeArguments.filterIsInstance<ConeClassLikeType >()
31
+ if (rootMarkers.size != callReturnType.typeArguments.size) return null
35
32
36
- val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
33
+ val newSchema: T = call.interpreterName(session)?.let { name ->
37
34
when (name) {
38
35
else -> name.load<Interpreter <* >>().let { processor ->
39
36
val dataFrameSchema = interpret(call, processor, reporter = reporter)
40
37
.let {
41
38
val value = it?.value
42
- if (value !is PluginDataFrameSchema ) {
39
+ if (value !is T ) {
43
40
if (! reporter.errorReported) {
44
- reporter.reportInterpretationError(call, " ${processor::class } must return ${PluginDataFrameSchema ::class } , but was ${ value} " )
41
+ reporter.reportInterpretationError(call, " ${processor::class } must return ${T ::class } , but was $value " )
45
42
}
46
43
return null
47
44
}
@@ -52,57 +49,10 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
52
49
}
53
50
} ? : return null
54
51
55
- return CallResult (rootMarker , newSchema)
52
+ return CallResult (rootMarkers , newSchema)
56
53
}
57
54
58
- fun KotlinTypeFacade.analyzeRefinedGroupByCallShape (call : FirFunctionCall , reporter : InterpretationErrorReporter ): GroupByCallResult ? {
59
- val callReturnType = call.resolvedType
60
- if (callReturnType.classId != GROUP_BY_CLASS_ID ) return null
61
- val keyMarker = callReturnType.typeArguments[0 ]
62
- val groupMarker = callReturnType.typeArguments[1 ]
63
- // rootMarker is expected to be a token generated by the plugin.
64
- // it's implied by "refined call"
65
- // thus ConeClassLikeType
66
- if (keyMarker !is ConeClassLikeType || groupMarker !is ConeClassLikeType ) {
67
- return null
68
- }
69
-
70
- val newSchema = call.interpreterName(session)?.let { name ->
71
- name.load<Interpreter <* >>().let { processor ->
72
- val dataFrameSchema = interpret(call, processor, reporter = reporter)
73
- .let {
74
- val value = it?.value
75
- if (value !is GroupBy ) {
76
- if (! reporter.errorReported) {
77
- reporter.reportInterpretationError(call, " ${processor::class } must return ${PluginDataFrameSchema ::class } , but was ${value} " )
78
- }
79
- return null
80
- }
81
- value
82
- }
83
-
84
- val keySchema = createPluginDataFrameSchema(dataFrameSchema.keys, dataFrameSchema.moveToTop)
85
- val groupSchema = PluginDataFrameSchema (dataFrameSchema.df.columns())
86
- GroupBySchema (keySchema, groupSchema)
87
- }
88
- } ? : return null
89
-
90
- return GroupByCallResult (keyMarker, newSchema.keySchema, groupMarker, newSchema.groupSchema)
91
- }
92
-
93
- data class GroupByCallResult (
94
- val keyMarker : ConeClassLikeType ,
95
- val keySchema : PluginDataFrameSchema ,
96
- val groupMarker : ConeClassLikeType ,
97
- val groupSchema : PluginDataFrameSchema ,
98
- )
99
-
100
- data class GroupBySchema (
101
- val keySchema : PluginDataFrameSchema ,
102
- val groupSchema : PluginDataFrameSchema ,
103
- )
104
-
105
- data class CallResult (val rootMarker : ConeClassLikeType , val newSchema : PluginDataFrameSchema )
55
+ data class CallResult <T >(val markers : List <ConeClassLikeType >, val result : T )
106
56
107
57
class RefinedArguments (val refinedArguments : List <RefinedArgument >) : List<RefinedArgument> by refinedArguments
108
58
0 commit comments