From d3a9d113f6dfa0a8b2aa0a26d45984775bcab21c Mon Sep 17 00:00:00 2001 From: Nikita Klimenko Date: Thu, 7 Mar 2024 20:59:24 +0200 Subject: [PATCH] then operation in pivot column selection DSL inside aggregate --- .../org/jetbrains/kotlinx/dataframe/api/pivot.kt | 2 +- .../impl/aggregation/PivotInAggregateImpl.kt | 3 ++- .../org/jetbrains/kotlinx/dataframe/api/pivot.kt | 15 +++++++++++++++ .../org/jetbrains/kotlinx/dataframe/api/pivot.kt | 2 +- .../impl/aggregation/PivotInAggregateImpl.kt | 3 ++- .../org/jetbrains/kotlinx/dataframe/api/pivot.kt | 15 +++++++++++++++ 6 files changed, 36 insertions(+), 4 deletions(-) diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt index dc781de85b..ed0bced180 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt @@ -148,7 +148,7 @@ public fun GroupBy<*, G>.pivotCounts(vararg columns: KProperty<*>, inward: B // region pivot -public fun AggregateGroupedDsl.pivot(inward: Boolean = true, columns: ColumnsSelector): PivotGroupBy = +public fun AggregateGroupedDsl.pivot(inward: Boolean = true, columns: PivotColumnsSelector): PivotGroupBy = PivotInAggregateImpl(this, columns, inward) public fun AggregateGroupedDsl.pivot(vararg columns: String, inward: Boolean = true): PivotGroupBy = diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt index 05aa22240f..b6f0db135c 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.aggregation.AggregateBody import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedDsl +import org.jetbrains.kotlinx.dataframe.api.PivotColumnsSelector import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy import org.jetbrains.kotlinx.dataframe.impl.api.AggregatedPivot import org.jetbrains.kotlinx.dataframe.impl.api.aggregatePivot @@ -11,7 +12,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet internal data class PivotInAggregateImpl( val aggregator: AggregateGroupedDsl, - val columns: ColumnsSelector, + val columns: PivotColumnsSelector, val inward: Boolean?, val default: Any? = null ) : PivotGroupBy, AggregatableInternal { diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt index a27a2592db..37489ec6e3 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt @@ -153,4 +153,19 @@ class PivotTests { 1, -1, 5 ) } + + @Test + fun `pivot then in aggregate`() { + val df = dataFrameOf( + "category1" to List(12) { it % 3 }, + "category2" to List(12) { "category2_${it % 2}" }, + "category3" to List(12) { "category3_${it % 5}" }, + "value" to List(12) { it } + ) + + val df1 = df.groupBy("category1").aggregate { + pivot { "category2" then "category3" }.count() + } + df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count() + } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt index dc781de85b..ed0bced180 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt @@ -148,7 +148,7 @@ public fun GroupBy<*, G>.pivotCounts(vararg columns: KProperty<*>, inward: B // region pivot -public fun AggregateGroupedDsl.pivot(inward: Boolean = true, columns: ColumnsSelector): PivotGroupBy = +public fun AggregateGroupedDsl.pivot(inward: Boolean = true, columns: PivotColumnsSelector): PivotGroupBy = PivotInAggregateImpl(this, columns, inward) public fun AggregateGroupedDsl.pivot(vararg columns: String, inward: Boolean = true): PivotGroupBy = diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt index 05aa22240f..b6f0db135c 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/PivotInAggregateImpl.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.aggregation.AggregateBody import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedDsl +import org.jetbrains.kotlinx.dataframe.api.PivotColumnsSelector import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy import org.jetbrains.kotlinx.dataframe.impl.api.AggregatedPivot import org.jetbrains.kotlinx.dataframe.impl.api.aggregatePivot @@ -11,7 +12,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet internal data class PivotInAggregateImpl( val aggregator: AggregateGroupedDsl, - val columns: ColumnsSelector, + val columns: PivotColumnsSelector, val inward: Boolean?, val default: Any? = null ) : PivotGroupBy, AggregatableInternal { diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt index a27a2592db..37489ec6e3 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt @@ -153,4 +153,19 @@ class PivotTests { 1, -1, 5 ) } + + @Test + fun `pivot then in aggregate`() { + val df = dataFrameOf( + "category1" to List(12) { it % 3 }, + "category2" to List(12) { "category2_${it % 2}" }, + "category3" to List(12) { "category3_${it % 5}" }, + "value" to List(12) { it } + ) + + val df1 = df.groupBy("category1").aggregate { + pivot { "category2" then "category3" }.count() + } + df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count() + } }