diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt index 050d8a7499..f0f8ae06e0 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt @@ -248,6 +248,8 @@ public fun DataFrame.add(body: AddDsl.() -> Unit): DataFrame { return dataFrameOf(this@add.columns() + dsl.columns).cast() } +@Refine +@Interpretable("GroupByAdd") public inline fun GroupBy.add( name: String, infer: Infer = Infer.Nulls, diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt index 7668541839..25bd79c5f5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt @@ -239,8 +239,8 @@ class FunctionCallTransformer( val groupMarker = rootMarkers[1] val (keySchema, groupSchema) = if (groupBy != null) { - val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop) - val groupSchema = PluginDataFrameSchema(groupBy.df.columns()) + val keySchema = groupBy.keys + val groupSchema = groupBy.groups keySchema to groupSchema } else { PluginDataFrameSchema.EMPTY to PluginDataFrameSchema.EMPTY diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt index 2fcc616f92..4bb4f4c7bf 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt @@ -1,6 +1,7 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter.* +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId import kotlin.properties.PropertyDelegateProvider @@ -35,3 +36,7 @@ internal fun AbstractInterpreter.ignore( ): ExpectedArgumentProvider = arg(name, lens = Interpreter.Id, defaultValue = Present(null)) +internal fun AbstractInterpreter.groupBy( + name: ArgumentName? = null +): ExpectedArgumentProvider = arg(name, lens = Interpreter.GroupBy) + diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt index d47473f76f..239147caa6 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/Interpreter.kt @@ -26,6 +26,8 @@ interface Interpreter { data object Schema : Lens + data object GroupBy : Lens + data object Id : Lens // required to compute whether resulting schema should be inheritor of previous class or a new class diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt index 06551c1cf1..5a88dee41e 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt @@ -27,6 +27,10 @@ data class PluginDataFrameSchema( } } +fun PluginDataFrameSchema.add(name: String, type: ConeKotlinType, context: KotlinTypeFacade): PluginDataFrameSchema { + return PluginDataFrameSchema(columns() + context.simpleColumnOf(name, type)) +} + private fun List.asString(indent: String = ""): String { return joinToString("\n") { val col = when (it) { diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 8ce8a8c696..0cdef0c63a 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -19,11 +19,14 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Present import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn +import org.jetbrains.kotlinx.dataframe.plugin.impl.add import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.type -class GroupBy(val df: PluginDataFrameSchema, val keys: List, val moveToTop: Boolean) +class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) class DataFrameGroupBy : AbstractInterpreter() { val Arguments.receiver: PluginDataFrameSchema by dataFrame() @@ -31,7 +34,7 @@ class DataFrameGroupBy : AbstractInterpreter() { val Arguments.cols: ColumnsResolver by arg() override fun Arguments.interpret(): GroupBy { - return GroupBy(receiver, cols.resolve(receiver), moveToTop) + return GroupBy(keys = createPluginDataFrameSchema(cols.resolve(receiver), moveToTop), groups = receiver) } } @@ -52,7 +55,7 @@ class GroupByInto : AbstractInterpreter() { } class Aggregate : AbstractSchemaModificationInterpreter() { - val Arguments.receiver: GroupBy by arg() + val Arguments.receiver: GroupBy by groupBy() val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id) override fun Arguments.interpret(): PluginDataFrameSchema { return aggregate( @@ -87,7 +90,7 @@ fun KotlinTypeFacade.aggregate( ) } - val cols = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop).columns() + dsl.columns.map { + val cols = groupBy.keys.columns() + dsl.columns.map { simpleColumnOf(it.name, it.type) } PluginDataFrameSchema(cols) @@ -144,13 +147,23 @@ fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List() { + val Arguments.receiver: GroupBy by groupBy() + val Arguments.name: String by arg() + val Arguments.type: TypeApproximation by type(name("expression")) + + override fun Arguments.interpret(): GroupBy { + return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this)) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt index 2c8fe0098f..de690d00ba 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt @@ -36,6 +36,7 @@ import org.jetbrains.kotlin.fir.references.resolved import org.jetbrains.kotlin.fir.references.symbol import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol import org.jetbrains.kotlin.fir.resolve.fqName +import org.jetbrains.kotlin.fir.resolve.fullyExpandedType import org.jetbrains.kotlin.fir.scopes.collectAllProperties import org.jetbrains.kotlin.fir.scopes.getProperties import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope @@ -78,6 +79,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation @@ -277,6 +279,17 @@ fun KotlinTypeFacade.interpret( } } + is Interpreter.GroupBy -> { + assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) { + "'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType" + } + + val resolvedType = it.expression.resolvedType.fullyExpandedType(session) + val keys = pluginDataFrameSchema(resolvedType.typeArguments[0]) + val groups = pluginDataFrameSchema(resolvedType.typeArguments[1]) + Interpreter.Success(GroupBy(keys, groups)) + } + is Interpreter.Id -> { Interpreter.Success(it.expression) } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 4d3db8733f..45dc022423 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -88,6 +88,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MoveAfter0 @@ -275,6 +276,7 @@ internal inline fun String.load(): T { "MoveToLeft1" -> MoveToLeft1() "MoveToRight0" -> MoveToRight0() "MoveAfter0" -> MoveAfter0() + "GroupByAdd" -> GroupByAdd() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/groupByAdd.kt b/plugins/kotlin-dataframe/testData/box/groupByAdd.kt new file mode 100644 index 0000000000..68ccfda00f --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupByAdd.kt @@ -0,0 +1,42 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.api.groupBy +import org.jetbrains.kotlinx.dataframe.io.* + +enum class State { + Idle, + Productive, + Maintenance, +} + +class Event(val toolId: String, val state: State, val timestamp: Long) + +fun box(): String { + val tool1 = "tool_1" + val tool2 = "tool_2" + val tool3 = "tool_3" + + val events = listOf( + Event(tool1, State.Idle, 0), + Event(tool1, State.Productive, 5), + Event(tool2, State.Idle, 0), + Event(tool2, State.Maintenance, 10), + Event(tool2, State.Idle, 20), + Event(tool3, State.Idle, 0), + Event(tool3, State.Productive, 25), + ).toDataFrame() + + val lastTimestamp = events.maxOf { timestamp } + val groupBy = events + .groupBy { toolId } + .sortBy { timestamp } + .add("stateDuration") { + (next()?.timestamp ?: lastTimestamp) - timestamp + }.toDataFrame() + + groupBy.group[0].stateDuration + + groupBy.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_extractSchema.kt b/plugins/kotlin-dataframe/testData/box/groupBy_extractSchema.kt new file mode 100644 index 0000000000..c7efc8a953 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_extractSchema.kt @@ -0,0 +1,14 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a", "b", "c")(1, 2, 3) + + val groupBy = df.groupBy { a } + + val df1 = groupBy.updateGroups { it.remove { a } }.toDataFrame() + df1.compileTimeSchema().print() + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 48887c251b..943be4285c 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -214,12 +214,24 @@ public void testGroupBy() { runTest("testData/box/groupBy.kt"); } + @Test + @TestMetadata("groupByAdd.kt") + public void testGroupByAdd() { + runTest("testData/box/groupByAdd.kt"); + } + @Test @TestMetadata("groupBy_DataRow.kt") public void testGroupBy_DataRow() { runTest("testData/box/groupBy_DataRow.kt"); } + @Test + @TestMetadata("groupBy_extractSchema.kt") + public void testGroupBy_extractSchema() { + runTest("testData/box/groupBy_extractSchema.kt"); + } + @Test @TestMetadata("groupBy_refine.kt") public void testGroupBy_refine() {