Skip to content

Commit b59cfb6

Browse files
authored
Merge pull request #204 from Kotlin/sort-grouped-df
Sort grouped df
2 parents 6f3773c + 8a46ded commit b59cfb6

File tree

5 files changed

+258
-12
lines changed

5 files changed

+258
-12
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ public interface SortDsl<out T> : ColumnsSelectionDsl<T> {
3333
public fun <C> KProperty<C?>.nullsLast(flag: Boolean = true): ColumnSet<C?> = toColumnAccessor().nullsLast(flag)
3434
}
3535

36+
/**
37+
* [SortColumnsSelector] is used to express or select multiple columns to sort by, represented by [ColumnSet]`<C>`,
38+
* using the context of [SortDsl]`<T>` as `this` and `it`.
39+
*
40+
* So:
41+
* ```kotlin
42+
* SortDsl<T>.(it: SortDsl<T>) -> ColumnSet<C>
43+
* ```
44+
*/
3645
public typealias SortColumnsSelector<T, C> = Selector<SortDsl<T>, ColumnSet<C>>
3746

3847
// region DataColumn

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DataFrameReceiver.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@ internal open class DataFrameReceiver<T>(
3636
private val unresolvedColumnsPolicy: UnresolvedColumnsPolicy
3737
) : DataFrameReceiverBase<T>(source.unbox()), SingleColumn<DataRow<T>> {
3838

39-
private fun <R> DataColumn<R>?.check(path: ColumnPath): DataColumn<R>? =
39+
private fun <R> DataColumn<R>?.check(path: ColumnPath): DataColumn<R> =
4040
when (this) {
4141
null -> when (unresolvedColumnsPolicy) {
42-
UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup<Any>(path, this@DataFrameReceiver).asDataColumn().cast()
42+
UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup<Any>(
43+
path,
44+
this@DataFrameReceiver
45+
).asDataColumn().cast()
46+
4347
UnresolvedColumnsPolicy.Fail -> error("Column $path not found")
4448
}
49+
4550
is MissingDataColumn -> this
4651
is ColumnGroup<*> -> ColumnGroupWithParent(null, this).asDataColumn().cast()
4752
else -> this

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/sort.kt

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,29 @@ import org.jetbrains.kotlinx.dataframe.columns.UnresolvedColumnsPolicy
1616
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
1717
import org.jetbrains.kotlinx.dataframe.impl.columns.addPath
1818
import org.jetbrains.kotlinx.dataframe.impl.columns.assertIsComparable
19+
import org.jetbrains.kotlinx.dataframe.impl.columns.missing.MissingColumnGroup
1920
import org.jetbrains.kotlinx.dataframe.impl.columns.resolve
2021
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumns
2122
import org.jetbrains.kotlinx.dataframe.kind
2223
import org.jetbrains.kotlinx.dataframe.nrow
2324

24-
internal fun <T, G> GroupBy<T, G>.sortByImpl(columns: SortColumnsSelector<G, *>): GroupBy<T, G> {
25-
return toDataFrame()
25+
@Suppress("UNCHECKED_CAST", "RemoveExplicitTypeArguments")
26+
internal fun <T, G> GroupBy<T, G>.sortByImpl(columns: SortColumnsSelector<G, *>): GroupBy<T, G> =
27+
toDataFrame()
28+
29+
// sort the individual groups by the columns specified
2630
.update { groups }
2731
.with { it.sortByImpl(UnresolvedColumnsPolicy.Skip, columns) }
32+
33+
// sort the groups by the columns specified (must be either be the keys column or "groups")
34+
// will do nothing if the columns specified are not the keys column or "groups"
2835
.sortByImpl(UnresolvedColumnsPolicy.Skip, columns as SortColumnsSelector<T, *>)
29-
.asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn() }
30-
}
36+
37+
.asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn<G>() }
3138

3239
internal fun <T, C> DataFrame<T>.sortByImpl(
3340
unresolvedColumnsPolicy: UnresolvedColumnsPolicy = UnresolvedColumnsPolicy.Fail,
34-
columns: SortColumnsSelector<T, C>
41+
columns: SortColumnsSelector<T, C>,
3542
): DataFrame<T> {
3643
val sortColumns = getSortColumns(columns, unresolvedColumnsPolicy)
3744
if (sortColumns.isEmpty()) return this
@@ -61,17 +68,17 @@ internal fun AnyCol.createComparator(nullsLast: Boolean): java.util.Comparator<I
6168

6269
internal fun <T, C> DataFrame<T>.getSortColumns(
6370
columns: SortColumnsSelector<T, C>,
64-
unresolvedColumnsPolicy: UnresolvedColumnsPolicy
65-
): List<SortColumnDescriptor<*>> {
66-
return columns.toColumns().resolve(this, unresolvedColumnsPolicy)
71+
unresolvedColumnsPolicy: UnresolvedColumnsPolicy,
72+
): List<SortColumnDescriptor<*>> =
73+
columns.toColumns().resolve(this, unresolvedColumnsPolicy)
74+
.filterNot { it.data is MissingColumnGroup<*> } // can appear using [DataColumn<R>?.check] with UnresolvedColumnsPolicy.Skip
6775
.map {
6876
when (val col = it.data) {
6977
is SortColumnDescriptor<*> -> col
7078
is ValueColumn<*> -> SortColumnDescriptor(col)
7179
else -> throw IllegalStateException("Can not use ${col.kind} as sort column")
7280
}
7381
}
74-
}
7582

7683
internal enum class SortFlag { Reversed, NullsLast }
7784

@@ -86,12 +93,14 @@ internal fun <C> ColumnWithPath<C>.addFlag(flag: SortFlag): ColumnWithPath<C> {
8693
SortFlag.NullsLast -> SortColumnDescriptor(col.column, col.direction, true)
8794
}
8895
}
96+
8997
is ValueColumn -> {
9098
when (flag) {
9199
SortFlag.Reversed -> SortColumnDescriptor(col, SortDirection.Desc)
92100
SortFlag.NullsLast -> SortColumnDescriptor(col, SortDirection.Asc, true)
93101
}
94102
}
103+
95104
else -> throw IllegalArgumentException("Can not apply sort flag to column kind ${col.kind}")
96105
}.addPath(path)
97106
}
@@ -103,7 +112,7 @@ internal class ColumnsWithSortFlag<C>(val column: ColumnSet<C>, val flag: SortFl
103112
internal class SortColumnDescriptor<C>(
104113
val column: ValueColumn<C>,
105114
val direction: SortDirection = SortDirection.Asc,
106-
val nullsLast: Boolean = false
115+
val nullsLast: Boolean = false,
107116
) : ValueColumn<C> by column
108117

109118
internal enum class SortDirection { Asc, Desc }
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package org.jetbrains.kotlinx.dataframe.api
2+
3+
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.DataFrame
5+
import org.jetbrains.kotlinx.dataframe.alsoDebug
6+
import org.jetbrains.kotlinx.dataframe.io.read
7+
import org.junit.Test
8+
9+
class SortGroupedDataframeTests {
10+
11+
@Test
12+
fun `Sorted grouped iris dataset`() {
13+
val irisData = DataFrame.read("src/test/resources/irisDataset.csv")
14+
irisData.alsoDebug()
15+
16+
irisData.groupBy("variety").let {
17+
it.sortBy("petal.length").toString() shouldBe
18+
it.sortBy { it["petal.length"] }.toString()
19+
}
20+
}
21+
22+
enum class State {
23+
Idle, Productive, Maintenance
24+
}
25+
26+
@Test
27+
fun test4() {
28+
class Event(val toolId: String, val state: State, val timestamp: Long)
29+
30+
val tool1 = "tool_1"
31+
val tool2 = "tool_2"
32+
val tool3 = "tool_3"
33+
34+
val events = listOf(
35+
Event(tool1, State.Idle, 0),
36+
Event(tool1, State.Productive, 5),
37+
Event(tool2, State.Idle, 0),
38+
Event(tool2, State.Maintenance, 10),
39+
Event(tool2, State.Idle, 20),
40+
Event(tool3, State.Idle, 0),
41+
Event(tool3, State.Productive, 25),
42+
).toDataFrame()
43+
44+
val lastTimestamp = events.maxOf { getValue<Long>("timestamp") }
45+
val groupBy = events
46+
.groupBy("toolId")
47+
.sortBy("timestamp")
48+
.add("stateDuration") {
49+
(next()?.getValue("timestamp") ?: lastTimestamp) - getValue<Long>("timestamp")
50+
}
51+
52+
groupBy.toDataFrame().alsoDebug()
53+
groupBy.schema().print()
54+
groupBy.keys.print()
55+
groupBy.keys[0].print()
56+
57+
val df1 = groupBy.updateGroups {
58+
val missingValues = State.values().asList().toDataFrame {
59+
"state" from { it }
60+
}
61+
62+
val df = it
63+
.fullJoin(missingValues, "state")
64+
.fillNulls("stateDuration")
65+
.with { 100L }
66+
67+
df.groupBy("state").sumFor("stateDuration")
68+
}
69+
70+
df1.toDataFrame().alsoDebug().isNotEmpty() shouldBe true
71+
}
72+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"sepal.length","sepal.width","petal.length","petal.width","variety"
2+
5.1,3.5,1.4,.2,"Setosa"
3+
4.9,3,1.4,.2,"Setosa"
4+
4.7,3.2,1.3,.2,"Setosa"
5+
4.6,3.1,1.5,.2,"Setosa"
6+
5,3.6,1.4,.2,"Setosa"
7+
5.4,3.9,1.7,.4,"Setosa"
8+
4.6,3.4,1.4,.3,"Setosa"
9+
5,3.4,1.5,.2,"Setosa"
10+
4.4,2.9,1.4,.2,"Setosa"
11+
4.9,3.1,1.5,.1,"Setosa"
12+
5.4,3.7,1.5,.2,"Setosa"
13+
4.8,3.4,1.6,.2,"Setosa"
14+
4.8,3,1.4,.1,"Setosa"
15+
4.3,3,1.1,.1,"Setosa"
16+
5.8,4,1.2,.2,"Setosa"
17+
5.7,4.4,1.5,.4,"Setosa"
18+
5.4,3.9,1.3,.4,"Setosa"
19+
5.1,3.5,1.4,.3,"Setosa"
20+
5.7,3.8,1.7,.3,"Setosa"
21+
5.1,3.8,1.5,.3,"Setosa"
22+
5.4,3.4,1.7,.2,"Setosa"
23+
5.1,3.7,1.5,.4,"Setosa"
24+
4.6,3.6,1,.2,"Setosa"
25+
5.1,3.3,1.7,.5,"Setosa"
26+
4.8,3.4,1.9,.2,"Setosa"
27+
5,3,1.6,.2,"Setosa"
28+
5,3.4,1.6,.4,"Setosa"
29+
5.2,3.5,1.5,.2,"Setosa"
30+
5.2,3.4,1.4,.2,"Setosa"
31+
4.7,3.2,1.6,.2,"Setosa"
32+
4.8,3.1,1.6,.2,"Setosa"
33+
5.4,3.4,1.5,.4,"Setosa"
34+
5.2,4.1,1.5,.1,"Setosa"
35+
5.5,4.2,1.4,.2,"Setosa"
36+
4.9,3.1,1.5,.2,"Setosa"
37+
5,3.2,1.2,.2,"Setosa"
38+
5.5,3.5,1.3,.2,"Setosa"
39+
4.9,3.6,1.4,.1,"Setosa"
40+
4.4,3,1.3,.2,"Setosa"
41+
5.1,3.4,1.5,.2,"Setosa"
42+
5,3.5,1.3,.3,"Setosa"
43+
4.5,2.3,1.3,.3,"Setosa"
44+
4.4,3.2,1.3,.2,"Setosa"
45+
5,3.5,1.6,.6,"Setosa"
46+
5.1,3.8,1.9,.4,"Setosa"
47+
4.8,3,1.4,.3,"Setosa"
48+
5.1,3.8,1.6,.2,"Setosa"
49+
4.6,3.2,1.4,.2,"Setosa"
50+
5.3,3.7,1.5,.2,"Setosa"
51+
5,3.3,1.4,.2,"Setosa"
52+
7,3.2,4.7,1.4,"Versicolor"
53+
6.4,3.2,4.5,1.5,"Versicolor"
54+
6.9,3.1,4.9,1.5,"Versicolor"
55+
5.5,2.3,4,1.3,"Versicolor"
56+
6.5,2.8,4.6,1.5,"Versicolor"
57+
5.7,2.8,4.5,1.3,"Versicolor"
58+
6.3,3.3,4.7,1.6,"Versicolor"
59+
4.9,2.4,3.3,1,"Versicolor"
60+
6.6,2.9,4.6,1.3,"Versicolor"
61+
5.2,2.7,3.9,1.4,"Versicolor"
62+
5,2,3.5,1,"Versicolor"
63+
5.9,3,4.2,1.5,"Versicolor"
64+
6,2.2,4,1,"Versicolor"
65+
6.1,2.9,4.7,1.4,"Versicolor"
66+
5.6,2.9,3.6,1.3,"Versicolor"
67+
6.7,3.1,4.4,1.4,"Versicolor"
68+
5.6,3,4.5,1.5,"Versicolor"
69+
5.8,2.7,4.1,1,"Versicolor"
70+
6.2,2.2,4.5,1.5,"Versicolor"
71+
5.6,2.5,3.9,1.1,"Versicolor"
72+
5.9,3.2,4.8,1.8,"Versicolor"
73+
6.1,2.8,4,1.3,"Versicolor"
74+
6.3,2.5,4.9,1.5,"Versicolor"
75+
6.1,2.8,4.7,1.2,"Versicolor"
76+
6.4,2.9,4.3,1.3,"Versicolor"
77+
6.6,3,4.4,1.4,"Versicolor"
78+
6.8,2.8,4.8,1.4,"Versicolor"
79+
6.7,3,5,1.7,"Versicolor"
80+
6,2.9,4.5,1.5,"Versicolor"
81+
5.7,2.6,3.5,1,"Versicolor"
82+
5.5,2.4,3.8,1.1,"Versicolor"
83+
5.5,2.4,3.7,1,"Versicolor"
84+
5.8,2.7,3.9,1.2,"Versicolor"
85+
6,2.7,5.1,1.6,"Versicolor"
86+
5.4,3,4.5,1.5,"Versicolor"
87+
6,3.4,4.5,1.6,"Versicolor"
88+
6.7,3.1,4.7,1.5,"Versicolor"
89+
6.3,2.3,4.4,1.3,"Versicolor"
90+
5.6,3,4.1,1.3,"Versicolor"
91+
5.5,2.5,4,1.3,"Versicolor"
92+
5.5,2.6,4.4,1.2,"Versicolor"
93+
6.1,3,4.6,1.4,"Versicolor"
94+
5.8,2.6,4,1.2,"Versicolor"
95+
5,2.3,3.3,1,"Versicolor"
96+
5.6,2.7,4.2,1.3,"Versicolor"
97+
5.7,3,4.2,1.2,"Versicolor"
98+
5.7,2.9,4.2,1.3,"Versicolor"
99+
6.2,2.9,4.3,1.3,"Versicolor"
100+
5.1,2.5,3,1.1,"Versicolor"
101+
5.7,2.8,4.1,1.3,"Versicolor"
102+
6.3,3.3,6,2.5,"Virginica"
103+
5.8,2.7,5.1,1.9,"Virginica"
104+
7.1,3,5.9,2.1,"Virginica"
105+
6.3,2.9,5.6,1.8,"Virginica"
106+
6.5,3,5.8,2.2,"Virginica"
107+
7.6,3,6.6,2.1,"Virginica"
108+
4.9,2.5,4.5,1.7,"Virginica"
109+
7.3,2.9,6.3,1.8,"Virginica"
110+
6.7,2.5,5.8,1.8,"Virginica"
111+
7.2,3.6,6.1,2.5,"Virginica"
112+
6.5,3.2,5.1,2,"Virginica"
113+
6.4,2.7,5.3,1.9,"Virginica"
114+
6.8,3,5.5,2.1,"Virginica"
115+
5.7,2.5,5,2,"Virginica"
116+
5.8,2.8,5.1,2.4,"Virginica"
117+
6.4,3.2,5.3,2.3,"Virginica"
118+
6.5,3,5.5,1.8,"Virginica"
119+
7.7,3.8,6.7,2.2,"Virginica"
120+
7.7,2.6,6.9,2.3,"Virginica"
121+
6,2.2,5,1.5,"Virginica"
122+
6.9,3.2,5.7,2.3,"Virginica"
123+
5.6,2.8,4.9,2,"Virginica"
124+
7.7,2.8,6.7,2,"Virginica"
125+
6.3,2.7,4.9,1.8,"Virginica"
126+
6.7,3.3,5.7,2.1,"Virginica"
127+
7.2,3.2,6,1.8,"Virginica"
128+
6.2,2.8,4.8,1.8,"Virginica"
129+
6.1,3,4.9,1.8,"Virginica"
130+
6.4,2.8,5.6,2.1,"Virginica"
131+
7.2,3,5.8,1.6,"Virginica"
132+
7.4,2.8,6.1,1.9,"Virginica"
133+
7.9,3.8,6.4,2,"Virginica"
134+
6.4,2.8,5.6,2.2,"Virginica"
135+
6.3,2.8,5.1,1.5,"Virginica"
136+
6.1,2.6,5.6,1.4,"Virginica"
137+
7.7,3,6.1,2.3,"Virginica"
138+
6.3,3.4,5.6,2.4,"Virginica"
139+
6.4,3.1,5.5,1.8,"Virginica"
140+
6,3,4.8,1.8,"Virginica"
141+
6.9,3.1,5.4,2.1,"Virginica"
142+
6.7,3.1,5.6,2.4,"Virginica"
143+
6.9,3.1,5.1,2.3,"Virginica"
144+
5.8,2.7,5.1,1.9,"Virginica"
145+
6.8,3.2,5.9,2.3,"Virginica"
146+
6.7,3.3,5.7,2.5,"Virginica"
147+
6.7,3,5.2,2.3,"Virginica"
148+
6.3,2.5,5,1.9,"Virginica"
149+
6.5,3,5.2,2,"Virginica"
150+
6.2,3.4,5.4,2.3,"Virginica"
151+
5.9,3,5.1,1.8,"Virginica"

0 commit comments

Comments
 (0)