Skip to content

Commit 96d3c83

Browse files
committed
started std rework
1 parent f7c9238 commit 96d3c83

File tree

6 files changed

+87
-157
lines changed

6 files changed

+87
-157
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/Defaults.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,16 @@ package org.jetbrains.kotlinx.dataframe.api
33
@PublishedApi
44
internal val skipNaN_default: Boolean = false
55

6+
/**
7+
* Default delta degrees of freedom for the standard deviation (std).
8+
*
9+
* The default is set to `1`,
10+
* meaning DataFrame uses [Bessel’s correction](https://en.wikipedia.org/wiki/Bessel%27s_correction) to calculate the
11+
* "unbiased sample standard deviation" by default.
12+
* This is also the standard in languages like R.
13+
*
14+
* This is different from the "population standard deviation" (where `ddof = 0`),
15+
* which is used in libraries like Numpy.
16+
*/
617
@PublishedApi
718
internal val ddof_default: Int = 1

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ internal object Aggregators {
6666
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
6767
)
6868

69+
private fun <Return : Number?> flattenReducingForNumbers(
70+
getReturnType: CalculateReturnType,
71+
reducer: Reducer<Number, Return>,
72+
) = Aggregator(
73+
aggregationHandler = ReducingAggregationHandler(reducer, getReturnType),
74+
inputHandler = NumberInputHandler(),
75+
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
76+
)
77+
6978
private fun <Return : Number?> twoStepReducingForNumbers(
7079
getReturnType: CalculateReturnType,
7180
reducer: Reducer<Number, Return>,
@@ -111,8 +120,8 @@ internal object Aggregators {
111120

112121
// T: Number? -> Double
113122
val std by withTwoOptions { skipNA: Boolean, ddof: Int ->
114-
flattenReducingForAny<Number, Double>(stdTypeConversion) { type ->
115-
asIterable().std(type, skipNA, ddof)
123+
flattenReducingForNumbers(stdTypeConversion) { type ->
124+
std(type, skipNA, ddof)
116125
}
117126
}
118127

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/corr.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import org.jetbrains.kotlinx.dataframe.api.isSuitableForCorr
1313
import org.jetbrains.kotlinx.dataframe.api.toValueColumn
1414
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
1515
import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
16-
import org.jetbrains.kotlinx.dataframe.math.varianceAndMean
16+
import org.jetbrains.kotlinx.dataframe.math.calculateBasicStatsOrNull
1717
import org.jetbrains.kotlinx.dataframe.nrow
1818
import kotlin.math.sqrt
1919

@@ -51,7 +51,7 @@ internal fun <T, C, R> Corr<T, C>.corrImpl(otherColumns: ColumnsSelector<T, R>):
5151
}
5252

5353
val stdMeans = cols.mapValues {
54-
it.value.toList().varianceAndMean()
54+
it.value.toList().calculateBasicStatsOrNull()
5555
}
5656

5757
val cache = mutableMapOf<Pair<ColumnPath, ColumnPath>, Double>()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
@file:Suppress("DuplicatedCode")
2+
3+
package org.jetbrains.kotlinx.dataframe.math
4+
5+
import org.jetbrains.kotlinx.dataframe.api.ddof_default
6+
import org.jetbrains.kotlinx.dataframe.api.skipNaN_default
7+
import kotlin.math.sqrt
8+
9+
internal data class BasicStats(val count: Int, val mean: Double, val variance: Double)
10+
11+
/**
12+
* Calculates the standard deviation from a [BasicStats] with optional delta degrees of freedom.
13+
*
14+
* @param ddof delta degrees of freedom, the bias-correction of std.
15+
* Default is [ddof_default], so `ddof = 1`, the "unbiased sample standard deviation", but alternatively,
16+
* the "population standard deviation", so `ddof = 0`, can be used.
17+
*/
18+
internal fun BasicStats.std(ddof: Int): Double =
19+
if (count <= ddof) {
20+
Double.NaN
21+
} else {
22+
sqrt(variance / (count - ddof))
23+
}
24+
25+
/**
26+
* Creates [BasicStats] instance for [this] sequence.
27+
*
28+
* This contains the [count][BasicStats.count], [mean][BasicStats.mean], and [variance][BasicStats.variance] and
29+
* can be used to efficiently calculate the [standard deviation][std].
30+
*/
31+
internal fun Sequence<Double>.calculateBasicStatsOrNull(skipNaN: Boolean = skipNaN_default): BasicStats? {
32+
var count = 0
33+
var sum = .0
34+
for (element in this) {
35+
if (element.isNaN()) {
36+
if (skipNaN) {
37+
continue
38+
} else {
39+
return null
40+
}
41+
}
42+
sum += element
43+
count++
44+
}
45+
val mean = sum / count
46+
var variance = .0
47+
for (element in this) {
48+
if (element.isNaN()) continue
49+
val diff = element - mean
50+
variance += diff * diff
51+
}
52+
return BasicStats(count = count, mean = mean, variance = variance)
53+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/std.kt

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,16 @@ import org.jetbrains.kotlinx.dataframe.impl.renderType
77
import java.math.BigDecimal
88
import java.math.BigInteger
99
import kotlin.reflect.KType
10-
import kotlin.reflect.full.withNullability
1110
import kotlin.reflect.typeOf
1211

1312
@Suppress("UNCHECKED_CAST")
1413
@PublishedApi
15-
internal fun <T : Number> Iterable<T?>.std(
16-
type: KType,
17-
skipNaN: Boolean = skipNaN_default,
18-
ddof: Int = ddof_default,
19-
): Double {
14+
internal fun <T : Number> Sequence<T?>.std(type: KType, skipNaN: Boolean, ddof: Int): Double {
2015
if (type.isMarkedNullable) {
21-
return when {
22-
skipNaN -> filterNotNull().std(type = type.withNullability(false), skipNaN = true, ddof = ddof)
23-
contains(null) -> Double.NaN
24-
else -> std(type = type.withNullability(false), skipNaN = false, ddof = ddof)
25-
}
16+
error("Encountered nullable type ${renderType(type)} in std function. This should not occur.")
2617
}
27-
return when (type.classifier) {
28-
Double::class -> (this as Iterable<Double>).std(skipNaN, ddof)
18+
return when (type) {
19+
typeOf<Double>() -> (this as Iterable<Double>).std(skipNaN, ddof)
2920
Float::class -> (this as Iterable<Float>).std(skipNaN, ddof)
3021
Int::class, Short::class, Byte::class -> (this as Iterable<Int>).std(ddof)
3122
Long::class -> (this as Iterable<Long>).std(ddof)
@@ -44,20 +35,20 @@ internal val stdTypeConversion: CalculateReturnType = { _, _ ->
4435

4536
@JvmName("doubleStd")
4637
internal fun Iterable<Double>.std(skipNaN: Boolean = skipNaN_default, ddof: Int = ddof_default): Double =
47-
varianceAndMean(skipNaN)?.std(ddof) ?: Double.NaN
38+
calculateBasicStatsOrNull(skipNaN)?.std(ddof) ?: Double.NaN
4839

4940
@JvmName("floatStd")
5041
internal fun Iterable<Float>.std(skipNaN: Boolean = skipNaN_default, ddof: Int = ddof_default): Double =
51-
varianceAndMean(skipNaN)?.std(ddof) ?: Double.NaN
42+
calculateBasicStatsOrNull(skipNaN)?.std(ddof) ?: Double.NaN
5243

5344
@JvmName("intStd")
54-
internal fun Iterable<Int>.std(ddof: Int = ddof_default): Double = varianceAndMean().std(ddof)
45+
internal fun Iterable<Int>.std(ddof: Int = ddof_default): Double = calculateBasicStatsOrNull().std(ddof)
5546

5647
@JvmName("longStd")
57-
internal fun Iterable<Long>.std(ddof: Int = ddof_default): Double = varianceAndMean().std(ddof)
48+
internal fun Iterable<Long>.std(ddof: Int = ddof_default): Double = calculateBasicStatsOrNull().std(ddof)
5849

5950
@JvmName("bigDecimalStd")
60-
internal fun Iterable<BigDecimal>.std(ddof: Int = ddof_default): Double = varianceAndMean().std(ddof)
51+
internal fun Iterable<BigDecimal>.std(ddof: Int = ddof_default): Double = calculateBasicStatsOrNull().std(ddof)
6152

6253
@JvmName("bigIntegerStd")
63-
internal fun Iterable<BigInteger>.std(ddof: Int = ddof_default): Double = varianceAndMean().std(ddof)
54+
internal fun Iterable<BigInteger>.std(ddof: Int = ddof_default): Double = calculateBasicStatsOrNull().std(ddof)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/stdMean.kt

Lines changed: 0 additions & 134 deletions
This file was deleted.

0 commit comments

Comments
 (0)