Skip to content

Commit c817531

Browse files
authored
CumSum (#1152)
* Reimplementation of cumSum:
  - removed big numbers
  - following type conversion of sum
  - added correct overloads
  - fixed behavior of float/double to match how it was described on the documentation site
* Fixed Netflix example notebook for cumSum
1 parent 5d22f48 commit c817531

File tree

9 files changed

+4041
-7617
lines changed

9 files changed

+4041
-7617
lines changed

core/api/core.api

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,6 +1740,18 @@ public final class org/jetbrains/kotlinx/dataframe/api/CumSumKt {
17401740
public static synthetic fun cumSum$default (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;
17411741
public static synthetic fun cumSum$default (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;
17421742
public static synthetic fun cumSum$default (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;
1743+
public static final fun cumSumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1744+
public static synthetic fun cumSumByte$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1745+
public static final fun cumSumDouble (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1746+
public static synthetic fun cumSumDouble$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1747+
public static final fun cumSumFloat (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1748+
public static synthetic fun cumSumFloat$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1749+
public static final fun cumSumNullableByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1750+
public static synthetic fun cumSumNullableByte$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1751+
public static final fun cumSumNullableShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1752+
public static synthetic fun cumSumNullableShort$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1753+
public static final fun cumSumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
1754+
public static synthetic fun cumSumShort$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
17431755
}
17441756

17451757
public final class org/jetbrains/kotlinx/dataframe/api/DataColumnArithmeticsKt {
@@ -1943,6 +1955,7 @@ public abstract interface class org/jetbrains/kotlinx/dataframe/api/DataSchemaEn
19431955

19441956
public final class org/jetbrains/kotlinx/dataframe/api/DefaultsKt {
19451957
public static final fun getDdofDefault ()I
1958+
public static final fun getDefaultCumSumSkipNA ()Z
19461959
public static final fun getSkipNaNDefault ()Z
19471960
}
19481961

@@ -6588,6 +6601,10 @@ public final class org/jetbrains/kotlinx/dataframe/keywords/SoftKeywords$Compani
65886601
public final fun getVALUES ()Ljava/util/List;
65896602
}
65906603

6604+
public final class org/jetbrains/kotlinx/dataframe/math/CumsumKt {
6605+
public static final fun getCumSumTypeConversion ()Lkotlin/jvm/functions/Function2;
6606+
}
6607+
65916608
public final class org/jetbrains/kotlinx/dataframe/math/MedianKt {
65926609
public static final fun medianOrNull (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Ljava/lang/Object;
65936610
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/Defaults.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,9 @@ internal val skipNaNDefault: Boolean = false
1616
*/
1717
@PublishedApi
1818
internal val ddofDefault: Int = 1
19+
20+
/**
21+
* Default for whether to skip nulls and NaNs in the cumSum operation.
22+
*/
23+
@PublishedApi
24+
internal val defaultCumSumSkipNA: Boolean = true
Lines changed: 72 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

3-
import org.jetbrains.kotlinx.dataframe.AnyColumnReference
43
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
54
import org.jetbrains.kotlinx.dataframe.DataColumn
65
import org.jetbrains.kotlinx.dataframe.DataFrame
76
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
87
import org.jetbrains.kotlinx.dataframe.api.Select.SelectSelectingOptions
8+
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
99
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1010
import org.jetbrains.kotlinx.dataframe.documentation.DocumentationUrls
1111
import org.jetbrains.kotlinx.dataframe.documentation.ExcludeFromSources
12-
import org.jetbrains.kotlinx.dataframe.impl.nothingType
13-
import org.jetbrains.kotlinx.dataframe.impl.nullableNothingType
14-
import org.jetbrains.kotlinx.dataframe.math.cumSum
15-
import org.jetbrains.kotlinx.dataframe.math.defaultCumSumSkipNA
16-
import org.jetbrains.kotlinx.dataframe.typeClass
17-
import java.math.BigDecimal
18-
import java.math.BigInteger
12+
import org.jetbrains.kotlinx.dataframe.math.cumSumImpl
1913
import kotlin.reflect.KProperty
20-
import kotlin.reflect.typeOf
2114

2215
// region DataColumn
2316

@@ -28,13 +21,14 @@ import kotlin.reflect.typeOf
2821
* from the first cell to the last cell.
2922
*
3023
* __NOTE:__ If the column contains nullable values and [skipNA\] is set to `true`,
31-
* null values are skipped when computing the cumulative sum.
32-
* Otherwise, any null value encountered will propagate null values in the output from that point onward.
24+
* null and NaN values are skipped when computing the cumulative sum.
25+
* When [skipNA\] is set to `false`, all values after the first NA will be NaN (for Double and Float columns)
26+
* or null (for integer columns).
3327
*
3428
* {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\]
3529
* The names of the columns to apply cumSum operation.}
3630
*
37-
* @param [skipNA\] Whether to skip null values (default: `true`).
31+
* @param [skipNA\] Whether to skip null and NaN values (default: `true`).
3832
*
3933
* @return A new {@get [CumSumDocs.DATA_TYPE]} of the same type with the cumulative sums.
4034
*
@@ -54,62 +48,62 @@ private interface CumSumDocs {
5448
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
5549
* {@set [CumSumDocs.CUMSUM_PARAM]}
5650
*/
57-
public fun <T : Number?> DataColumn<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<T> =
58-
when (type()) {
59-
typeOf<Double>() -> cast<Double>().cumSum(skipNA).cast()
60-
61-
typeOf<Double?>() -> cast<Double?>().cumSum(skipNA).cast()
62-
63-
typeOf<Float>() -> cast<Float>().cumSum(skipNA).cast()
64-
65-
typeOf<Float?>() -> cast<Float?>().cumSum(skipNA).cast()
66-
67-
typeOf<Int>() -> cast<Int>().cumSum().cast()
68-
69-
// TODO cumSum for Byte returns Int but is converted back to T: Byte, Issue #558
70-
typeOf<Byte>() -> cast<Byte>().cumSum().map { it.toByte() }.cast()
71-
72-
// TODO cumSum for Short returns Int but is converted back to T: Short, Issue #558
73-
typeOf<Short>() -> cast<Short>().cumSum().map { it.toShort() }.cast()
74-
75-
typeOf<Int?>() -> cast<Int?>().cumSum(skipNA).cast()
76-
77-
// TODO cumSum for Byte? returns Int? but is converted back to T: Byte?, Issue #558
78-
typeOf<Byte?>() -> cast<Byte?>().cumSum(skipNA).map { it?.toByte() }.cast()
79-
80-
// TODO cumSum for Short? returns Int? but is converted back to T: Short?, Issue #558
81-
typeOf<Short?>() -> cast<Short?>().cumSum(skipNA).map { it?.toShort() }.cast()
82-
83-
typeOf<Long>() -> cast<Long>().cumSum().cast()
84-
85-
typeOf<Long?>() -> cast<Long?>().cumSum(skipNA).cast()
86-
87-
typeOf<BigInteger>() -> cast<BigInteger>().cumSum().cast()
51+
@JvmName("cumSumShort")
52+
public fun DataColumn<Short>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int> =
53+
cumSumImpl(type(), skipNA).cast()
8854

89-
typeOf<BigInteger?>() -> cast<BigInteger?>().cumSum(skipNA).cast()
90-
91-
typeOf<BigDecimal>() -> cast<BigDecimal>().cumSum().cast()
55+
/**
56+
* {@include [CumSumDocs]}
57+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
58+
* {@set [CumSumDocs.CUMSUM_PARAM]}
59+
*/
60+
@JvmName("cumSumNullableShort")
61+
public fun DataColumn<Short?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int?> =
62+
cumSumImpl(type(), skipNA).cast()
9263

93-
typeOf<BigDecimal?>() -> cast<BigDecimal?>().cumSum(skipNA).cast()
64+
/**
65+
* {@include [CumSumDocs]}
66+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
67+
* {@set [CumSumDocs.CUMSUM_PARAM]}
68+
*/
69+
@JvmName("cumSumByte")
70+
public fun DataColumn<Byte>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int> =
71+
cumSumImpl(type(), skipNA).cast()
9472

95-
typeOf<Number?>(), typeOf<Number>() -> convertToDouble().cumSum(skipNA).cast()
73+
/**
74+
* {@include [CumSumDocs]}
75+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
76+
* {@set [CumSumDocs.CUMSUM_PARAM]}
77+
*/
78+
@JvmName("cumSumNullableByte")
79+
public fun DataColumn<Byte?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int?> =
80+
cumSumImpl(type(), skipNA).cast()
9681

97-
// Cumsum for empty column or column with just null is itself
98-
nothingType, nullableNothingType -> this
82+
/**
83+
* {@include [CumSumDocs]}
84+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
85+
* {@set [CumSumDocs.CUMSUM_PARAM]}
86+
*/
87+
@JvmName("cumSumDouble")
88+
public fun DataColumn<Double?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Double> =
89+
cumSumImpl(type(), skipNA).cast()
9990

100-
else -> error("Cumsum for type ${type()} is not supported")
101-
}
91+
/**
92+
* {@include [CumSumDocs]}
93+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
94+
* {@set [CumSumDocs.CUMSUM_PARAM]}
95+
*/
96+
@JvmName("cumSumFloat")
97+
public fun DataColumn<Float?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Float> =
98+
cumSumImpl(type(), skipNA).cast()
10299

103-
private val supportedClasses = setOf(
104-
Double::class,
105-
Float::class,
106-
Int::class,
107-
Byte::class,
108-
Short::class,
109-
Long::class,
110-
BigInteger::class,
111-
BigDecimal::class,
112-
)
100+
/**
101+
* {@include [CumSumDocs]}
102+
* {@set [CumSumDocs.DATA_TYPE] [DataColumn]}
103+
* {@set [CumSumDocs.CUMSUM_PARAM]}
104+
*/
105+
public fun <T : Number?> DataColumn<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<T> =
106+
cumSumImpl(type(), skipNA).cast()
113107

114108
// endregion
115109

@@ -119,26 +113,25 @@ private val supportedClasses = setOf(
119113
* {@include [CumSumDocs]}
120114
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
121115
*/
122-
public fun <T, C> DataFrame<T>.cumSum(
116+
public fun <T, C : Number?> DataFrame<T>.cumSum(
123117
skipNA: Boolean = defaultCumSumSkipNA,
124118
columns: ColumnsSelector<T, C>,
125-
): DataFrame<T> =
126-
convert(columns).to { if (it.typeClass in supportedClasses) it.cast<Number?>().cumSum(skipNA) else it }
119+
): DataFrame<T> = convert(columns).to { it.cumSum(skipNA) }
127120

128121
/**
129122
* {@include [CumSumDocs]}
130123
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
131124
*/
132125
public fun <T> DataFrame<T>.cumSum(vararg columns: String, skipNA: Boolean = defaultCumSumSkipNA): DataFrame<T> =
133-
cumSum(skipNA) { columns.toColumnSet() }
126+
cumSum(skipNA) { columns.toColumnSet().cast() }
134127

135128
/**
136129
* {@include [CumSumDocs]}
137130
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
138131
*/
139132
@AccessApiOverload
140133
public fun <T> DataFrame<T>.cumSum(
141-
vararg columns: AnyColumnReference,
134+
vararg columns: ColumnReference<Number?>,
142135
skipNA: Boolean = defaultCumSumSkipNA,
143136
): DataFrame<T> = cumSum(skipNA) { columns.toColumnSet() }
144137

@@ -147,8 +140,10 @@ public fun <T> DataFrame<T>.cumSum(
147140
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
148141
*/
149142
@AccessApiOverload
150-
public fun <T> DataFrame<T>.cumSum(vararg columns: KProperty<*>, skipNA: Boolean = defaultCumSumSkipNA): DataFrame<T> =
151-
cumSum(skipNA) { columns.toColumnSet() }
143+
public fun <T> DataFrame<T>.cumSum(
144+
vararg columns: KProperty<Number?>,
145+
skipNA: Boolean = defaultCumSumSkipNA,
146+
): DataFrame<T> = cumSum(skipNA) { columns.toColumnSet() }
152147

153148
/**
154149
* {@include [CumSumDocs]}
@@ -157,7 +152,8 @@ public fun <T> DataFrame<T>.cumSum(vararg columns: KProperty<*>, skipNA: Boolean
157152
*/
158153
public fun <T> DataFrame<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataFrame<T> =
159154
cumSum(skipNA) {
160-
colsAtAnyDepth { !it.isColumnGroup() }
155+
// TODO keep at any depth?
156+
colsAtAnyDepth { it.isNumber() }.cast()
161157
}
162158

163159
// endregion
@@ -168,7 +164,7 @@ public fun <T> DataFrame<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataF
168164
* {@include [CumSumDocs]}
169165
* {@set [CumSumDocs.DATA_TYPE] [GroupBy]}
170166
*/
171-
public fun <T, G, C> GroupBy<T, G>.cumSum(
167+
public fun <T, G, C : Number?> GroupBy<T, G>.cumSum(
172168
skipNA: Boolean = defaultCumSumSkipNA,
173169
columns: ColumnsSelector<G, C>,
174170
): GroupBy<T, G> = updateGroups { cumSum(skipNA, columns) }
@@ -178,15 +174,15 @@ public fun <T, G, C> GroupBy<T, G>.cumSum(
178174
* {@set [CumSumDocs.DATA_TYPE] [GroupBy]}
179175
*/
180176
public fun <T, G> GroupBy<T, G>.cumSum(vararg columns: String, skipNA: Boolean = defaultCumSumSkipNA): GroupBy<T, G> =
181-
cumSum(skipNA) { columns.toColumnSet() }
177+
cumSum(skipNA) { columns.toColumnSet().cast() }
182178

183179
/**
184180
* {@include [CumSumDocs]}
185181
* {@set [CumSumDocs.DATA_TYPE] [GroupBy]}
186182
*/
187183
@AccessApiOverload
188184
public fun <T, G> GroupBy<T, G>.cumSum(
189-
vararg columns: AnyColumnReference,
185+
vararg columns: ColumnReference<Number?>,
190186
skipNA: Boolean = defaultCumSumSkipNA,
191187
): GroupBy<T, G> = cumSum(skipNA) { columns.toColumnSet() }
192188

@@ -196,7 +192,7 @@ public fun <T, G> GroupBy<T, G>.cumSum(
196192
*/
197193
@AccessApiOverload
198194
public fun <T, G> GroupBy<T, G>.cumSum(
199-
vararg columns: KProperty<*>,
195+
vararg columns: KProperty<Number?>,
200196
skipNA: Boolean = defaultCumSumSkipNA,
201197
): GroupBy<T, G> = cumSum(skipNA) { columns.toColumnSet() }
202198

@@ -207,7 +203,8 @@ public fun <T, G> GroupBy<T, G>.cumSum(
207203
*/
208204
public fun <T, G> GroupBy<T, G>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): GroupBy<T, G> =
209205
cumSum(skipNA) {
210-
colsAtAnyDepth { !it.isColumnGroup() }
206+
// TODO keep at any depth?
207+
colsAtAnyDepth { it.isNumber() }.cast()
211208
}
212209

213210
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ internal object Aggregators {
208208
}
209209

210210
// T: Number -> T
211+
// Byte -> Int
212+
// Short -> Int
213+
// Nothing -> Double
211214
val sum by withOneOption { skipNaN: Boolean ->
212215
twoStepReducingForNumbers(sumTypeConversion) { type ->
213216
sum(type, skipNaN)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/inputHandlers/NumberInputHandler.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ internal class NumberInputHandler<out Return : Any?> : AggregatorInputHandler<Nu
105105
* If no valid unification can be found or the input is solely [Number]`(?)`, the type [Number]`(?)` is returned.
106106
*/
107107
override fun calculateValueType(valueTypes: Set<KType>): ValueType {
108-
val unifiedType = valueTypes.unifiedNumberTypeOrNull(UnifiedNumberTypeOptions.Companion.PRIMITIVES_ONLY)
108+
val unifiedType = valueTypes.unifiedNumberTypeOrNull(UnifiedNumberTypeOptions.PRIMITIVES_ONLY)
109109
?: typeOf<Number>().withNullability(valueTypes.any { it.isMarkedNullable })
110110

111111
if (unifiedType.isSubtypeOf(typeOf<Double?>()) &&

0 commit comments

Comments
 (0)