Skip to content

Commit b95b767

Browse files
authored
Merge pull request #836 from Kotlin/ermolenko/822
Improve dataframe sorting in KTNB UI by handling non-comparable columns
2 parents d60b5de + 3daf3e3 commit b95b767

File tree

4 files changed

+542
-34
lines changed

4 files changed

+542
-34
lines changed

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/KotlinNotebookPluginUtils.kt

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.jupyter
33
import org.jetbrains.kotlinx.dataframe.AnyCol
44
import org.jetbrains.kotlinx.dataframe.AnyFrame
55
import org.jetbrains.kotlinx.dataframe.AnyRow
6+
import org.jetbrains.kotlinx.dataframe.DataRow
67
import org.jetbrains.kotlinx.dataframe.api.Convert
78
import org.jetbrains.kotlinx.dataframe.api.FormatClause
89
import org.jetbrains.kotlinx.dataframe.api.FormattedFrame
@@ -25,12 +26,13 @@ import org.jetbrains.kotlinx.dataframe.api.Update
2526
import org.jetbrains.kotlinx.dataframe.api.at
2627
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
2728
import org.jetbrains.kotlinx.dataframe.api.frames
29+
import org.jetbrains.kotlinx.dataframe.api.getColumn
2830
import org.jetbrains.kotlinx.dataframe.api.into
29-
import org.jetbrains.kotlinx.dataframe.api.sortBy
31+
import org.jetbrains.kotlinx.dataframe.api.isComparable
32+
import org.jetbrains.kotlinx.dataframe.api.sortWith
3033
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
3134
import org.jetbrains.kotlinx.dataframe.api.values
3235
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
33-
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
3436
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
3537

3638
/**
@@ -62,6 +64,7 @@ public object KotlinNotebookPluginUtils {
6264

6365
/**
6466
* Sorts a dataframe-like object by multiple columns.
67+
* If a column type is not comparable, sorting by string representation is applied instead.
6568
*
6669
* @param dataFrameLike The dataframe-like object to sort.
6770
* @param columnPaths The list of columns to sort by. Each element in the list represents a column path
@@ -79,27 +82,78 @@ public object KotlinNotebookPluginUtils {
7982
}
8083

8184
/**
82-
* Sorts the given data frame by the specified columns.
85+
* Sorts a dataframe by multiple columns with specified sorting order for each column.
86+
* If a column type is not comparable, sorting by string representation is applied instead.
8387
*
84-
* @param df The data frame to be sorted.
85-
* @param columnPaths The paths of the columns to be sorted. Each path is represented as a list of strings.
86-
* @param isDesc A list of booleans indicating whether each column should be sorted in descending order.
87-
* The size of this list must be equal to the size of the columnPaths list.
88-
* @return The sorted data frame.
88+
* @param df The dataframe to be sorted.
89+
* @param columnPaths A list of column paths where each path is a list of strings representing the hierarchical path of the column.
90+
* @param isDesc A list of boolean values indicating whether each column should be sorted in descending order;
91+
* true for descending, false for ascending. The size of this list should match the size of `columnPaths`.
92+
* @return The sorted dataframe.
8993
*/
90-
public fun sortByColumns(df: AnyFrame, columnPaths: List<List<String>>, isDesc: List<Boolean>): AnyFrame =
91-
df.sortBy {
92-
require(columnPaths.all { it.isNotEmpty() })
93-
require(columnPaths.size == isDesc.size)
94+
public fun sortByColumns(df: AnyFrame, columnPaths: List<List<String>>, isDesc: List<Boolean>): AnyFrame {
95+
require(columnPaths.all { it.isNotEmpty() })
96+
require(columnPaths.size == isDesc.size)
97+
98+
val sortKeys = columnPaths.map { path ->
99+
ColumnPath(path)
100+
}
101+
102+
val comparator = createComparator(sortKeys, isDesc)
94103

95-
val sortKeys = columnPaths.map { path ->
96-
ColumnPath(path)
104+
return df.sortWith(comparator)
105+
}
106+
107+
private fun createComparator(sortKeys: List<ColumnPath>, isDesc: List<Boolean>): Comparator<DataRow<*>> {
108+
return Comparator { row1, row2 ->
109+
for ((key, desc) in sortKeys.zip(isDesc)) {
110+
val comparisonResult = if (row1.df().getColumn(key).isComparable()) {
111+
compareComparableValues(row1, row2, key, desc)
112+
} else {
113+
compareStringValues(row1, row2, key, desc)
114+
}
115+
// If a comparison result is non-zero, we have resolved the ordering
116+
if (comparisonResult != 0) return@Comparator comparisonResult
97117
}
118+
// All comparisons are equal
119+
0
120+
}
121+
}
98122

99-
(sortKeys zip isDesc).map { (key, desc) ->
100-
if (desc) key.desc() else key
101-
}.toColumnSet()
123+
@Suppress("UNCHECKED_CAST")
124+
private fun compareComparableValues(
125+
row1: DataRow<*>,
126+
row2: DataRow<*>,
127+
key: ColumnPath,
128+
desc: Boolean,
129+
): Int {
130+
val firstValue = row1.getValueOrNull(key) as Comparable<Any?>?
131+
val secondValue = row2.getValueOrNull(key) as Comparable<Any?>?
132+
133+
return when {
134+
firstValue == null && secondValue == null -> 0
135+
firstValue == null -> if (desc) 1 else -1
136+
secondValue == null -> if (desc) -1 else 1
137+
desc -> secondValue.compareTo(firstValue)
138+
else -> firstValue.compareTo(secondValue)
102139
}
140+
}
141+
142+
private fun compareStringValues(
143+
row1: DataRow<*>,
144+
row2: DataRow<*>,
145+
key: ColumnPath,
146+
desc: Boolean,
147+
): Int {
148+
val firstValue = (row1.getValueOrNull(key)?.toString() ?: "")
149+
val secondValue = (row2.getValueOrNull(key)?.toString() ?: "")
150+
151+
return if (desc) {
152+
secondValue.compareTo(firstValue)
153+
} else {
154+
firstValue.compareTo(secondValue)
155+
}
156+
}
103157

104158
internal fun isDataframeConvertable(dataframeLike: Any?): Boolean =
105159
when (dataframeLike) {

core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/RenderingTests.kt

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,206 @@ class RenderingTests : JupyterReplTestCase() {
406406
}
407407
}
408408

409+
@Test
410+
fun `test sortByColumns by int column`() {
411+
val json = executeScriptAndParseDataframeResult(
412+
"""
413+
val df = dataFrameOf("nums")(5, 4, 3, 2, 1)
414+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("nums")), listOf(false))
415+
KotlinNotebookPluginUtils.convertToDataFrame(res)
416+
""".trimIndent(),
417+
)
418+
419+
val rows = json[KOTLIN_DATAFRAME]!!.jsonArray
420+
json.extractColumn<Int>(0, "nums") shouldBe 1
421+
json.extractColumn<Int>(rows.size - 1, "nums") shouldBe 5
422+
}
423+
424+
internal inline fun <reified T> JsonObject.extractColumn(index: Int, fieldName: String): T {
425+
val element = this[KOTLIN_DATAFRAME]!!.jsonArray[index].jsonObject[fieldName]!!.jsonPrimitive
426+
return when (T::class) {
427+
String::class -> element.content as T
428+
Int::class -> element.int as T
429+
else -> throw IllegalArgumentException("Unsupported type")
430+
}
431+
}
432+
433+
@Test
434+
fun `test sortByColumns by multiple int columns`() {
435+
val json = executeScriptAndParseDataframeResult(
436+
"""
437+
data class Row(val a: Int, val b: Int)
438+
val df = listOf(Row(1, 1), Row(1, 2), Row(2, 3), Row(2, 4), Row(3, 5), Row(3, 6)).toDataFrame()
439+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("a"), listOf("b")), listOf(true, false))
440+
KotlinNotebookPluginUtils.convertToDataFrame(res)
441+
""".trimIndent(),
442+
)
443+
444+
json.extractColumn<Int>(0, "a") shouldBe 3
445+
json.extractColumn<Int>(0, "b") shouldBe 5
446+
json.extractColumn<Int>(5, "a") shouldBe 1
447+
json.extractColumn<Int>(5, "b") shouldBe 2
448+
}
449+
450+
@Test
451+
fun `test sortByColumns by single string column`() {
452+
val json = executeScriptAndParseDataframeResult(
453+
"""
454+
val df = dataFrameOf("letters")("e", "d", "c", "b", "a")
455+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("letters")), listOf(true))
456+
KotlinNotebookPluginUtils.convertToDataFrame(res)
457+
""".trimIndent(),
458+
)
459+
460+
json.extractColumn<String>(0, "letters") shouldBe "e"
461+
json.extractColumn<String>(4, "letters") shouldBe "a"
462+
}
463+
464+
@Test
465+
fun `test sortByColumns by multiple string columns`() {
466+
val json = executeScriptAndParseDataframeResult(
467+
"""
468+
data class Row(val first: String, val second: String)
469+
val df = listOf(Row("a", "b"), Row("a", "a"), Row("b", "b"), Row("b", "a")).toDataFrame()
470+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("first"), listOf("second")), listOf(false, true))
471+
KotlinNotebookPluginUtils.convertToDataFrame(res)
472+
""".trimIndent(),
473+
)
474+
475+
json.extractColumn<String>(0, "first") shouldBe "a"
476+
json.extractColumn<String>(0, "second") shouldBe "b"
477+
json.extractColumn<String>(3, "first") shouldBe "b"
478+
json.extractColumn<String>(3, "second") shouldBe "a"
479+
}
480+
481+
@Test
482+
fun `test sortByColumns by mix of int and string columns`() {
483+
val json = executeScriptAndParseDataframeResult(
484+
"""
485+
data class Row(val num: Int, val letter: String)
486+
val df = listOf(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"), Row(3, "a")).toDataFrame()
487+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("num"), listOf("letter")), listOf(true, false))
488+
KotlinNotebookPluginUtils.convertToDataFrame(res)
489+
""".trimIndent(),
490+
)
491+
492+
json.extractColumn<Int>(0, "num") shouldBe 3
493+
json.extractColumn<String>(0, "letter") shouldBe "a"
494+
json.extractColumn<Int>(4, "num") shouldBe 1
495+
json.extractColumn<String>(4, "letter") shouldBe "b"
496+
}
497+
498+
@Test
499+
fun `test sortByColumns by multiple non-comparable column`() {
500+
val json = executeScriptAndParseDataframeResult(
501+
"""
502+
data class Person(val name: String, val age: Int) {
503+
override fun toString(): String {
504+
return age.toString()
505+
}
506+
}
507+
val df = dataFrameOf("urls", "person")(
508+
URL("https://example.com/a"), Person("Alice", 10),
509+
URL("https://example.com/b"), Person("Bob", 11),
510+
URL("https://example.com/a"), Person("Nick", 12),
511+
URL("https://example.com/b"), Person("Guy", 13),
512+
)
513+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("urls"), listOf("person")), listOf(false, true))
514+
KotlinNotebookPluginUtils.convertToDataFrame(res)
515+
""".trimIndent(),
516+
)
517+
518+
json.extractColumn<Int>(0, "person") shouldBe 12
519+
json.extractColumn<Int>(3, "person") shouldBe 11
520+
}
521+
522+
@Test
523+
fun `test sortByColumns by mix of comparable and non-comparable columns`() {
524+
val json = executeScriptAndParseDataframeResult(
525+
"""
526+
val df = dataFrameOf("urls", "id")(
527+
URL("https://example.com/a"), 1,
528+
URL("https://example.com/b"), 2,
529+
URL("https://example.com/a"), 2,
530+
URL("https://example.com/b"), 1,
531+
)
532+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("id"), listOf("urls")), listOf(true, true))
533+
KotlinNotebookPluginUtils.convertToDataFrame(res)
534+
""".trimIndent(),
535+
)
536+
537+
json.extractColumn<String>(0, "urls") shouldBe "https://example.com/b"
538+
json.extractColumn<Int>(0, "id") shouldBe 2
539+
json.extractColumn<String>(3, "urls") shouldBe "https://example.com/a"
540+
json.extractColumn<Int>(3, "id") shouldBe 1
541+
}
542+
543+
@Test
544+
fun `test sortByColumns by url column`() {
545+
val json = executeScriptAndParseDataframeResult(
546+
"""
547+
val df = dataFrameOf("urls")(
548+
URL("https://example.com/a"),
549+
URL("https://example.com/c"),
550+
URL("https://example.com/b"),
551+
URL("https://example.com/d")
552+
)
553+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("urls")), listOf(false))
554+
KotlinNotebookPluginUtils.convertToDataFrame(res)
555+
""".trimIndent(),
556+
)
557+
558+
json.extractColumn<String>(0, "urls") shouldBe "https://example.com/a"
559+
json.extractColumn<String>(1, "urls") shouldBe "https://example.com/b"
560+
json.extractColumn<String>(2, "urls") shouldBe "https://example.com/c"
561+
json.extractColumn<String>(3, "urls") shouldBe "https://example.com/d"
562+
}
563+
564+
@Test
565+
fun `test sortByColumns by column group children`() {
566+
val json = executeScriptAndParseDataframeResult(
567+
"""
568+
val df = dataFrameOf(
569+
"a" to listOf(5, 4, 3, 2, 1),
570+
"b" to listOf(1, 2, 3, 4, 5)
571+
)
572+
val res = KotlinNotebookPluginUtils.sortByColumns(df.group("a", "b").into("c"), listOf(listOf("c", "a")), listOf(false))
573+
KotlinNotebookPluginUtils.convertToDataFrame(res)
574+
""".trimIndent(),
575+
)
576+
577+
fun JsonObject.extractBFields(): List<Int> {
578+
val dataframe = this[KOTLIN_DATAFRAME]!!.jsonArray
579+
return dataframe.map { it.jsonObject["c"]!!.jsonObject["data"]!!.jsonObject["b"]!!.jsonPrimitive.int }
580+
}
581+
582+
val bFields = json.extractBFields()
583+
bFields shouldBe listOf(5, 4, 3, 2, 1)
584+
}
585+
586+
@Test
587+
fun `test sortByColumns for column that contains string and int`() {
588+
val json = executeScriptAndParseDataframeResult(
589+
"""
590+
val df = dataFrameOf("mixed")(
591+
5,
592+
"10",
593+
2,
594+
"4",
595+
"1"
596+
)
597+
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("mixed")), listOf(true))
598+
KotlinNotebookPluginUtils.convertToDataFrame(res)
599+
""".trimIndent(),
600+
)
601+
602+
json.extractColumn<String>(0, "mixed") shouldBe "5"
603+
json.extractColumn<String>(1, "mixed") shouldBe "4"
604+
json.extractColumn<String>(2, "mixed") shouldBe "2"
605+
json.extractColumn<String>(3, "mixed") shouldBe "10"
606+
json.extractColumn<String>(4, "mixed") shouldBe "1"
607+
}
608+
409609
companion object {
410610
/**
411611
* Set the system property for the IDE version needed for specific serialization testing purposes.

0 commit comments

Comments
 (0)