Skip to content

Commit f6d3309

Browse files
committed
initial rework of median without percentile
1 parent b3363ff commit f6d3309

File tree

3 files changed

+205
-17
lines changed

3 files changed

+205
-17
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

3+
import org.jetbrains.kotlinx.dataframe.api.skipNaN_default
34
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.ReducingAggregationHandler
45
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
56
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.AnyInputHandler
67
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
78
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
89
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
910
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
11+
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
1012
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
1113
import org.jetbrains.kotlinx.dataframe.math.maxOrNull
1214
import org.jetbrains.kotlinx.dataframe.math.maxTypeConversion
1315
import org.jetbrains.kotlinx.dataframe.math.mean
1416
import org.jetbrains.kotlinx.dataframe.math.meanTypeConversion
15-
import org.jetbrains.kotlinx.dataframe.math.median
17+
import org.jetbrains.kotlinx.dataframe.math.medianConversion
18+
import org.jetbrains.kotlinx.dataframe.math.medianOrNull
1619
import org.jetbrains.kotlinx.dataframe.math.minOrNull
1720
import org.jetbrains.kotlinx.dataframe.math.minTypeConversion
1821
import org.jetbrains.kotlinx.dataframe.math.percentile
@@ -29,13 +32,23 @@ internal object Aggregators {
2932
/**
 * Builds an [Aggregator] out of a [SelectingAggregationHandler], an [AnyInputHandler] and a
 * [TwoStepMultipleColumnsHandler].
 *
 * NOTE(review): this span is a rendered diff. The `stepOneReducer` lines are the removed side and the
 * `stepOneSelector` lines are the added side of the same parameter rename.
 *
 * @param getReturnType forwarded as the third argument of [SelectingAggregationHandler].
 * @param indexOfResult forwarded as the second argument of [SelectingAggregationHandler].
 * @param stepOneSelector forwarded as the first argument of [SelectingAggregationHandler].
 */
private fun <Value : Return & Any, Return : Any?> twoStepSelectingForAny(
3033
getReturnType: CalculateReturnType,
3134
indexOfResult: IndexOfResult<Value>,
32-
stepOneReducer: Reducer<Value, Return>,
35+
stepOneSelector: Selector<Value, Return>,
3336
) = Aggregator(
34-
aggregationHandler = SelectingAggregationHandler(stepOneReducer, indexOfResult, getReturnType),
37+
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
3538
inputHandler = AnyInputHandler(),
3639
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
3740
)
3841

42+
/**
 * Builds an [Aggregator] out of a [SelectingAggregationHandler], an [AnyInputHandler] and a
 * [FlatteningMultipleColumnsHandler].
 *
 * @param getReturnType forwarded as the third argument of [SelectingAggregationHandler].
 * @param indexOfResult forwarded as the second argument of [SelectingAggregationHandler].
 * @param selector forwarded as the first argument of [SelectingAggregationHandler].
 */
private fun <Value : Return & Any, Return : Any?> flattenSelectingForAny(
    getReturnType: CalculateReturnType,
    indexOfResult: IndexOfResult<Value>,
    selector: Selector<Value, Return>,
) = Aggregator(
    multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
    inputHandler = AnyInputHandler(),
    aggregationHandler = SelectingAggregationHandler(selector, indexOfResult, getReturnType),
)
51+
3952
private fun <Value : Any, Return : Any?> twoStepReducingForAny(
4053
getReturnType: CalculateReturnType,
4154
stepOneReducer: Reducer<Value, Return>,
@@ -101,7 +114,7 @@ internal object Aggregators {
101114
// Minimum aggregator, created per `skipNaN` option through `withOneOption`:
// the value comes from `minOrNull`, the index from `indexOfMin` and the return type from `minTypeConversion`.
// NOTE(review): rendered diff; `stepOneReducer` is the removed side, `stepOneSelector` the added side.
private val min by withOneOption { skipNaN: Boolean ->
102115
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
103116
getReturnType = minTypeConversion,
104-
stepOneReducer = { type -> minOrNull(type, skipNaN) },
117+
stepOneSelector = { type -> minOrNull(type, skipNaN) },
105118
indexOfResult = { type -> indexOfMin(type, skipNaN) },
106119
)
107120
}
@@ -113,15 +126,15 @@ internal object Aggregators {
113126
// Maximum aggregator, created per `skipNaN` option through `withOneOption`:
// the value comes from `maxOrNull`, the index from `indexOfMax` and the return type from `maxTypeConversion`.
// NOTE(review): rendered diff; `stepOneReducer` is the removed side, `stepOneSelector` the added side.
private val max by withOneOption { skipNaN: Boolean ->
114127
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
115128
getReturnType = maxTypeConversion,
116-
stepOneReducer = { type -> maxOrNull(type, skipNaN) },
129+
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
117130
indexOfResult = { type -> indexOfMax(type, skipNaN) },
118131
)
119132
}
120133

121134
// T: Number? -> Double
122-
// Standard-deviation aggregator with two options (`skipNaN`, `ddof`), built with
// `flattenReducingForNumbers(stdTypeConversion)` and delegating the computation to `std(type, skipNaN, ddof)`.
// NOTE(review): rendered diff; the `skipNA` lines are the removed side, the `skipNaN` lines the added side.
val std by withTwoOptions { skipNA: Boolean, ddof: Int ->
135+
val std by withTwoOptions { skipNaN: Boolean, ddof: Int ->
123136
flattenReducingForNumbers(stdTypeConversion) { type ->
124-
std(type, skipNA, ddof)
137+
std(type, skipNaN, ddof)
125138
}
126139
}
127140

@@ -140,9 +153,21 @@ internal object Aggregators {
140153
}
141154
}
142155

156+
/**
 * Median aggregator for self-comparable values.
 *
 * Obtains the aggregator from the private `median` option switch using [skipNaN_default] and
 * casts it to `Aggregator<T & Any, T?>`.
 */
@JvmName("medianComparable")
fun <T : Comparable<T & Any>?> median(): Aggregator<T & Any, T?> {
    val aggregator = median.invoke(skipNaN_default)
    return aggregator.cast2()
}
158+
159+
/**
 * Median aggregator for comparable numbers, producing a `Double`.
 *
 * Obtains the aggregator from the private `median` option switch for the given [skipNaN] value and
 * casts it to `Aggregator<T & Any, Double>`.
 *
 * @param skipNaN option value handed to the `median` option switch.
 */
@JvmName("medianNumber")
fun <T> median(skipNaN: Boolean): Aggregator<T & Any, Double> where T : Comparable<T & Any>?, T : Number? {
    val aggregator = median.invoke(skipNaN)
    return aggregator.cast2()
}
162+
143163
// T: Comparable<T>? -> T
144-
// Median aggregator, created per `skipNaN` option through `withOneOption`:
// the value comes from `medianOrNull`, the index from `indexOfMedian` and the return type from `medianConversion`.
// `medianOrNull` is declared to return `Any?`, hence the unchecked cast of its result to `Comparable<Any>?`.
// NOTE(review): rendered diff; the `flattenReducingForAny` / `asIterable().median(type)` lines are the removed
// side, everything from `@Suppress` onwards is the added side.
val median by flattenReducingForAny<Comparable<Any?>> { type ->
145-
asIterable().median(type)
164+
@Suppress("UNCHECKED_CAST")
165+
private val median by withOneOption { skipNaN: Boolean ->
166+
flattenSelectingForAny<Comparable<Any>, Comparable<Any>?>(
167+
getReturnType = medianConversion,
168+
selector = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
169+
indexOfResult = { type -> indexOfMedian(type, skipNaN) },
170+
)
146171
}
147172

148173
// T: Number -> T
Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,177 @@
11
package org.jetbrains.kotlinx.dataframe.math
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import org.jetbrains.kotlinx.dataframe.api.isNaN
5+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
6+
import org.jetbrains.kotlinx.dataframe.impl.canBeNaN
7+
import org.jetbrains.kotlinx.dataframe.impl.isIntraComparable
8+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
9+
import org.jetbrains.kotlinx.dataframe.impl.nothingType
10+
import org.jetbrains.kotlinx.dataframe.impl.renderType
11+
import org.jetbrains.kotlinx.dataframe.math.quickSelect
12+
import java.math.BigDecimal
13+
import java.math.BigInteger
314
import kotlin.reflect.KType
15+
import kotlin.reflect.full.withNullability
416
import kotlin.reflect.typeOf
517

18+
private val logger = KotlinLogging.logger { }
19+
620
// TODO median always returns the same type, but this can be confusing for iterables of even length
721
// TODO (e.g. median of [1, 2] should be 1.5, but the type is Int, so it returns 1), Issue #558
22+
23+
/**
 * Returns the median of the comparable input:
 * - `null` if empty and not a primitive number
 * - `Double.NaN` if empty and primitive number
 * - `Double` if primitive number
 * - `Double.NaN` if ![skipNaN] and contains NaN
 * - (lower) middle else
 *
 * For an even number of `Double`, `Float`, `Int`, `Short`, `Byte` or `Long` values the result is the mean of
 * the two middle values, computed as `Double`. A warning is logged for `Long` input because of the
 * conversion to `Double`.
 *
 * @param type the type of the elements; must be non-nullable and intra-comparable.
 * @param skipNaN when `false`, a NaN in the input makes the result `Double.NaN`;
 *   NaN values are otherwise excluded from the calculation.
 * @return the median as described above.
 * @throws IllegalStateException if [type] is marked nullable or is not intra-comparable.
 * @throws IllegalArgumentException if [type] is [BigDecimal] or [BigInteger].
 *
 * TODO migrate back to percentile when it's flexible enough
 */
@PublishedApi
internal fun <T : Comparable<T>> Sequence<T>.medianOrNull(type: KType, skipNaN: Boolean): Any? {
    when {
        type.isMarkedNullable ->
            error("Encountered nullable type ${renderType(type)} in median function. This should not occur.")

        !type.isIntraComparable() ->
            error(
                "Unable to compute the median for ${renderType(type)}. " +
                    "Only primitive numbers or self-comparables are supported.",
            )

        type == typeOf<BigDecimal>() || type == typeOf<BigInteger>() ->
            throw IllegalArgumentException(
                "Cannot calculate the median for big numbers in DataFrame. Only primitive numbers are supported.",
            )

        type == typeOf<Long>() ->
            logger.warn { "Converting Longs to Doubles to calculate the median, loss of precision may occur." }

        // this means the sequence is empty
        type == nothingType -> return null
    }

    // propagate NaN to return if they are not to be skipped
    if (type.canBeNaN && !skipNaN && any { it.isNaN }) return Double.NaN

    val list = when {
        type.canBeNaN -> filter { !it.isNaN }
        else -> this
    }.toList()

    val size = list.size
    if (size == 0) return if (type.isPrimitiveNumber()) Double.NaN else null

    val middleIndex = (size - 1) / 2
    val lower = list.quickSelect(middleIndex)

    // An odd-sized input has exactly one middle element. Selecting `middleIndex + 1` here would pick the
    // next order statistic (or run past the end for a single element), so return the middle element directly.
    if (size % 2 != 0) {
        return if (type.isPrimitiveNumber()) (lower as Number).toDouble() else lower
    }

    // even-sized input: the upper middle sits right after the lower middle
    val upper = list.quickSelect(middleIndex + 1)

    return when (type) {
        typeOf<Double>(),
        typeOf<Float>(),
        typeOf<Int>(),
        typeOf<Short>(),
        typeOf<Byte>(),
        typeOf<Long>(),
        -> ((lower as Number).toDouble() + (upper as Number).toDouble()) / 2.0

        // non-numeric comparables cannot be averaged; prefer the lower middle
        else -> lower
    }
}
92+
93+
/**
 * Return-type calculation for the median:
 * - Primitive Number -> Double
 * - T : Comparable<T> -> T, marked nullable when `isEmpty` is `true`
 *
 * Fails with an error for any other type.
 */
internal val medianConversion: CalculateReturnType = { inputType, isEmpty ->
    if (inputType.isPrimitiveNumber()) {
        // uses linear interpolation, number 7 of Hyndman and Fan "Sample quantiles in statistical packages"
        typeOf<Double>()
    } else if (inputType.isIntraComparable()) {
        // closest rank method, preferring lower middle,
        // number 3 of Hyndman and Fan "Sample quantiles in statistical packages"
        inputType.withNullability(isEmpty)
    } else {
        error("Can not calculate median for type ${renderType(inputType)}")
    }
}
109+
110+
/**
 * Returns the index of the median of the comparable input:
 * - `-1` if empty, all `null`, or nothing is left after dropping NaN values
 * - index of first NaN if ![skipNaN] and contains NaN
 * - index (lower) middle else
 * NOTE: For primitive numbers the `seq.elementAt(seq.indexOfMedian())` might be different from `seq.medianOrNull()`
 *
 * The returned index refers to the position in the original sequence, `null` and NaN elements included.
 *
 * @param type the type of the elements; nullability is ignored.
 * @param skipNaN when `false`, the index of the first NaN is returned if there is one;
 *   NaN values are otherwise excluded from the calculation.
 * @throws IllegalStateException if the non-null [type] is not intra-comparable.
 * @throws IllegalArgumentException if the non-null [type] is [BigDecimal] or [BigInteger].
 *
 * TODO migrate back to percentile when it's flexible enough
 */
internal fun <T : Comparable<T & Any>?> Sequence<T>.indexOfMedian(type: KType, skipNaN: Boolean): Int {
    val nonNullType = type.withNullability(false)
    when {
        !nonNullType.isIntraComparable() ->
            error(
                "Unable to compute the median for ${renderType(type)}. " +
                    "Only primitive numbers or self-comparables are supported.",
            )

        nonNullType == typeOf<BigDecimal>() || nonNullType == typeOf<BigInteger>() ->
            throw IllegalArgumentException(
                "Cannot calculate the median for big numbers in DataFrame. Only primitive numbers are supported.",
            )

        // this means the sequence is empty
        nonNullType == nothingType -> return -1
    }

    // propagate NaN to return if they are not to be skipped
    if (nonNullType.canBeNaN && !skipNaN) {
        val firstNaN = indexOfFirst { it.isNaN }
        if (firstNaN >= 0) return firstNaN
    }

    // remember each value's original position before nulls (and NaNs) are dropped
    val indexedSequence = this.mapIndexedNotNull { i, it ->
        if (it == null) {
            null
        } else {
            IndexedComparable(i, it)
        }
    }
    val list = when {
        nonNullType.canBeNaN -> indexedSequence.filterNot { it.value.isNaN }
        else -> indexedSequence
    }.toList()

    if (list.isEmpty()) return -1

    // Only the (lower) middle element is needed. Selecting `middleIndex + 1` as well would run past the end
    // for a single element and pick a different order statistic for every other odd-sized input.
    val middleIndex = (list.size - 1) / 2
    return list.quickSelect(middleIndex).index
}
173+
174+
/**
 * Pairs a comparable [value] with the [index] it was found at.
 *
 * Ordering looks at [value] only; [index] takes no part in [compareTo].
 */
private data class IndexedComparable<T : Comparable<T>>(
    val index: Int,
    val value: T,
) : Comparable<IndexedComparable<T>> {

    override fun compareTo(other: IndexedComparable<T>): Int {
        val otherValue = other.value
        return value.compareTo(otherValue)
    }
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/percentile.kt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@ import org.jetbrains.kotlinx.dataframe.impl.asList
44
import java.math.BigDecimal
55
import java.math.BigInteger
66
import kotlin.reflect.KType
7-
import kotlin.reflect.typeOf
87

98
@PublishedApi
10-
internal inline fun <reified T : Comparable<T>> Iterable<T?>.percentile(
11-
percentile: Double,
12-
type: KType = typeOf<T>(),
13-
): T? {
9+
internal fun <T : Comparable<T>> Iterable<T?>.percentile(percentile: Double, type: KType): T? {
1410
require(percentile in 0.0..100.0) { "Percentile must be in range [0, 100]" }
1511

1612
@Suppress("UNCHECKED_CAST")
@@ -26,7 +22,7 @@ internal inline fun <reified T : Comparable<T>> Iterable<T?>.percentile(
2622
val lower = list.quickSelect(index)
2723
val upper = list.quickSelect(index + 1)
2824

29-
return when (type.classifier) {
25+
return when (type) {
3026
Double::class -> ((lower as Double + upper as Double) / 2.0) as T
3127
Float::class -> ((lower as Float + upper as Float) / 2.0f) as T
3228
Int::class -> ((lower as Int + upper as Int) / 2) as T

0 commit comments

Comments
 (0)