Skip to content

Commit 2143dfc

Browse files
committed
[Compiler plugin] Extract GroupBy from updated type instead of recursively calling interpreters
1 parent 22da430 commit 2143dfc

File tree

7 files changed

+51
-9
lines changed

7 files changed

+51
-9
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ class FunctionCallTransformer(
239239
val groupMarker = rootMarkers[1]
240240

241241
val (keySchema, groupSchema) = if (groupBy != null) {
242-
val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop)
243-
val groupSchema = PluginDataFrameSchema(groupBy.df.columns())
242+
val keySchema = groupBy.keys
243+
val groupSchema = groupBy.groups
244244
keySchema to groupSchema
245245
} else {
246246
PluginDataFrameSchema.EMPTY to PluginDataFrameSchema.EMPTY

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.plugin.impl
22

33
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter.*
4+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
45
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
56
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
67
import kotlin.properties.PropertyDelegateProvider
@@ -35,3 +36,7 @@ internal fun <T> AbstractInterpreter<T>.ignore(
3536
): ExpectedArgumentProvider<Nothing?> =
3637
arg(name, lens = Interpreter.Id, defaultValue = Present(null))
3738

39+
internal fun <T> AbstractInterpreter<T>.groupBy(
40+
name: ArgumentName? = null
41+
): ExpectedArgumentProvider<GroupBy> = arg(name, lens = Interpreter.GroupBy)
42+

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ interface Interpreter<T> {
2626

2727
data object Schema : Lens
2828

29+
data object GroupBy : Lens
30+
2931
data object Id : Lens
3032

3133
// required to compute whether resulting schema should be inheritor of previous class or a new class

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,19 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
2121
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
2222
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
2323
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
24+
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
2425
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
26+
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
2527

26-
class GroupBy(val df: PluginDataFrameSchema, val keys: List<ColumnWithPathApproximation>, val moveToTop: Boolean)
28+
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)
2729

2830
class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
2931
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
3032
val Arguments.moveToTop: Boolean by arg(defaultValue = Present(true))
3133
val Arguments.cols: ColumnsResolver by arg()
3234

3335
override fun Arguments.interpret(): GroupBy {
34-
return GroupBy(receiver, cols.resolve(receiver), moveToTop)
36+
return GroupBy(keys = createPluginDataFrameSchema(cols.resolve(receiver), moveToTop), groups = receiver)
3537
}
3638
}
3739

@@ -52,7 +54,7 @@ class GroupByInto : AbstractInterpreter<Unit>() {
5254
}
5355

5456
class Aggregate : AbstractSchemaModificationInterpreter() {
55-
val Arguments.receiver: GroupBy by arg()
57+
val Arguments.receiver: GroupBy by groupBy()
5658
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
5759
override fun Arguments.interpret(): PluginDataFrameSchema {
5860
return aggregate(
@@ -87,7 +89,7 @@ fun KotlinTypeFacade.aggregate(
8789
)
8890
}
8991

90-
val cols = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop).columns() + dsl.columns.map {
92+
val cols = groupBy.keys.columns() + dsl.columns.map {
9193
simpleColumnOf(it.name, it.type)
9294
}
9395
PluginDataFrameSchema(cols)
@@ -144,13 +146,13 @@ fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List<ColumnWithPathApprox
144146
}
145147

146148
class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
147-
val Arguments.receiver: GroupBy by arg()
149+
val Arguments.receiver: GroupBy by groupBy()
148150
val Arguments.groupedColumnName: String? by arg(defaultValue = Present(null))
149151

150152
override fun Arguments.interpret(): PluginDataFrameSchema {
151-
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.df.columns()))
153+
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.groups.columns()))
152154
return PluginDataFrameSchema(
153-
createPluginDataFrameSchema(receiver.keys, receiver.moveToTop).columns() + grouped
155+
receiver.keys.columns() + grouped
154156
)
155157
}
156158
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.jetbrains.kotlin.fir.references.resolved
3636
import org.jetbrains.kotlin.fir.references.symbol
3737
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
3838
import org.jetbrains.kotlin.fir.resolve.fqName
39+
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
3940
import org.jetbrains.kotlin.fir.scopes.collectAllProperties
4041
import org.jetbrains.kotlin.fir.scopes.getProperties
4142
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
@@ -78,6 +79,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
7879
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
7980
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
8081
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
82+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
8183
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
8284
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
8385

@@ -277,6 +279,17 @@ fun <T> KotlinTypeFacade.interpret(
277279
}
278280
}
279281

282+
is Interpreter.GroupBy -> {
283+
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
284+
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
285+
}
286+
287+
val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
288+
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
289+
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
290+
Interpreter.Success(GroupBy(keys, groups))
291+
}
292+
280293
is Interpreter.Id -> {
281294
Interpreter.Success(it.expression)
282295
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val df = dataFrameOf("a", "b", "c")(1, 2, 3)
8+
9+
val groupBy = df.groupBy { a }
10+
11+
val df1 = groupBy.updateGroups { it.remove { a } }.toDataFrame()
12+
df1.compileTimeSchema().print()
13+
return "OK"
14+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,12 @@ public void testGroupBy_DataRow() {
220220
runTest("testData/box/groupBy_DataRow.kt");
221221
}
222222

223+
@Test
224+
@TestMetadata("groupBy_extractSchema.kt")
225+
public void testGroupBy_extractSchema() {
226+
runTest("testData/box/groupBy_extractSchema.kt");
227+
}
228+
223229
@Test
224230
@TestMetadata("groupBy_refine.kt")
225231
public void testGroupBy_refine() {

0 commit comments

Comments
 (0)