Skip to content

Commit 3f43197

Browse files
committed
overhauling median
1 parent 2e86d2f commit 3f43197

File tree

7 files changed

+324
-117
lines changed

7 files changed

+324
-117
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.maxOrNull(skipNaN: Boolean = skipN
3636

3737
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxBy(
3838
skipNaN: Boolean = skipNaNDefault,
39-
noinline selector: (T) -> R,
39+
crossinline selector: (T) -> R,
4040
): T & Any = maxByOrNull(skipNaN, selector).suggestIfNull("maxBy")
4141

4242
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxByOrNull(
4343
skipNaN: Boolean = skipNaNDefault,
44-
noinline selector: (T) -> R,
44+
crossinline selector: (T) -> R,
4545
): T? = Aggregators.max<R>(skipNaN).aggregateByOrNull(this, selector)
4646

4747
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOf(

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

Lines changed: 249 additions & 98 deletions
Large diffs are not rendered by default.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.minOrNull(skipNaN: Boolean = skipN
3636

3737
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minBy(
3838
skipNaN: Boolean = skipNaNDefault,
39-
noinline selector: (T) -> R,
39+
crossinline selector: (T) -> R,
4040
): T & Any = minByOrNull(skipNaN, selector).suggestIfNull("minBy")
4141

4242
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minByOrNull(
4343
skipNaN: Boolean = skipNaNDefault,
44-
noinline selector: (T) -> R,
44+
crossinline selector: (T) -> R,
4545
): T? = Aggregators.min<R>(skipNaN).aggregateByOrNull(this, selector)
4646

4747
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOf(

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
2020
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
2121
import org.jetbrains.kotlinx.dataframe.math.percentile
2222
import kotlin.reflect.KProperty
23+
import kotlin.reflect.typeOf
2324

2425
// region DataColumn
2526

@@ -52,7 +53,7 @@ public fun AnyRow.rowPercentile(percentile: Double): Any =
5253
rowPercentileOrNull(percentile).suggestIfNull("rowPercentile")
5354

5455
public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOfOrNull(percentile: Double): T? =
55-
valuesOf<T>().percentile(percentile)
56+
valuesOf<T>().percentile(percentile, typeOf<T>())
5657

5758
public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOf(percentile: Double): T =
5859
rowPercentileOfOrNull<T>(percentile).suggestIfNull("rowPercentileOf")

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

3-
import org.jetbrains.kotlinx.dataframe.api.skipNaN_default
3+
import org.jetbrains.kotlinx.dataframe.api.skipNaNDefault
44
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.ReducingAggregationHandler
55
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
66
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.AnyInputHandler
@@ -153,13 +153,23 @@ internal object Aggregators {
153153
}
154154
}
155155

156-
fun <T> median(): Aggregator<T & Any, T?>
156+
// T : primitive Number? -> Double
157+
// T : Comparable<T & Any>? -> T?
158+
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
159+
where T : Comparable<T & Any>? =
160+
median.invoke(skipNaN).cast2()
161+
162+
// T : Comparable<T & Any>? -> T?
163+
fun <T> medianComparables(): Aggregator<T & Any, T?>
157164
where T : Comparable<T & Any>? =
158-
median.invoke(skipNaN_default).cast2()
165+
medianCommon<T>(skipNaNDefault).cast2()
159166

160-
fun <T> median(skipNaN: Boolean): Aggregator<T & Any, Double>
167+
// T : primitive Number? -> Double
168+
fun <T> medianNumbers(
169+
skipNaN: Boolean,
170+
): Aggregator<T & Any, Double>
161171
where T : Comparable<T & Any>?, T : Number? =
162-
median.invoke(skipNaN).cast2()
172+
medianCommon<T>(skipNaN).cast2()
163173

164174
// T: Comparable<T>? -> T
165175
@Suppress("UNCHECKED_CAST")

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@ internal inline fun <T> Aggregatable<T>.remainingColumns(
1717
crossinline predicate: (AnyCol) -> Boolean,
1818
): ColumnsSelector<T, Any?> = remainingColumnsSelector().filter { predicate(it.data) }
1919

20+
/**
21+
* Emulates selecting all columns whose values are comparable to each other.
22+
* These are columns of type `R` where `R : Comparable<R>`.
23+
*
24+
* There is no way to denote this generically in types, however,
25+
* hence the _fake_ type `Comparable<Any>` is used.
26+
* (`Comparable<Nothing>` would be more correct, but then the compiler complains)
27+
*/
2028
@Suppress("UNCHECKED_CAST")
21-
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any?>> =
22-
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any?>>
29+
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any>?> =
30+
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any>?>
2331

2432
@Suppress("UNCHECKED_CAST")
2533
internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/median.kt

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import org.jetbrains.kotlinx.dataframe.api.groupBy
88
import org.jetbrains.kotlinx.dataframe.api.mapToColumn
99
import org.jetbrains.kotlinx.dataframe.api.median
1010
import org.jetbrains.kotlinx.dataframe.api.medianOf
11-
import org.jetbrains.kotlinx.dataframe.api.rowMedian
11+
import org.jetbrains.kotlinx.dataframe.api.rowMedianOf
12+
import org.jetbrains.kotlinx.dataframe.statistics.myFun
1213
import org.junit.Test
14+
import kotlin.experimental.ExperimentalTypeInference
1315
import kotlin.reflect.typeOf
1416

1517
@Suppress("ktlint:standard:argument-list-wrapping")
@@ -36,12 +38,19 @@ class MedianTests {
3638

3739
@Test
3840
fun `median of two columns`() {
39-
val df = dataFrameOf("a", "b")(
40-
1, 4,
41-
2, 6,
42-
7, 7,
41+
val df = dataFrameOf("a", "b", "c")(
42+
1, 4, "a",
43+
2, 6, "b",
44+
7, 7, "c",
4345
)
44-
df.median("a", "b") shouldBe 5
46+
df.median("a", "b") shouldBe 5.0
47+
df.median { "a"<Int>() and "b"<Int>() } shouldBe 5.0
48+
df.median("c") shouldBe "b"
49+
50+
df.median { "c"<String>() } shouldBe "b"
51+
52+
df.median({ "c"<String>() }) shouldBe "b"
53+
df.median<_, String> { "c"<String>() } shouldBe "b"
4554
}
4655

4756
@Test
@@ -51,6 +60,34 @@ class MedianTests {
5160
2, 4,
5261
7, 7,
5362
)
54-
df.mapToColumn("", Infer.Type) { it.rowMedian() } shouldBe columnOf(2, 3, 7)
63+
df.mapToColumn("", Infer.Type) { it.rowMedianOf<Int>() } shouldBe columnOf(2, 3, 7)
5564
}
5665
}
66+
67+
fun <T> List<T>.myFun(): Int where T : Comparable<T> = TODO()
68+
fun <T> List<T>.myFun(): Double where T : Comparable<T>, T : Number = TODO()
69+
70+
fun <T> myFun(list: List<T>): Int where T : Comparable<T> = TODO()
71+
fun <T> myFun(list: List<T>): Double where T : Comparable<T>, T : Number = TODO()
72+
73+
@OptIn(ExperimentalTypeInference::class)
74+
@OverloadResolutionByLambdaReturnType
75+
//@JvmName("jnkjsdnf")
76+
fun <T> myFun(get : () -> T): Int where T : Comparable<T> = TODO()
77+
78+
@JvmName("jnkjsdnf")
79+
@OptIn(ExperimentalTypeInference::class)
80+
@OverloadResolutionByLambdaReturnType
81+
fun <T> myFun(get : () -> T): Double where T : Comparable<T>, T : Number = TODO()
82+
83+
fun main() {
84+
val res1 = listOf(1, 2, 3).myFun()
85+
val res2 = listOf("a", "b", "c").myFun()
86+
87+
val res3 = myFun(listOf(1, 2, 3))
88+
val res4 = myFun(listOf("a", "b", "c"))
89+
90+
val res5 = myFun { 1 }
91+
val res6 = myFun { "" }
92+
val res7 = myFun<String> { "" }
93+
}

0 commit comments

Comments
 (0)