1+ @file:OptIn(ExperimentalTypeInference ::class )
2+ @file:Suppress(" LocalVariableName" )
3+
14package org.jetbrains.kotlinx.dataframe.api
25
36import org.jetbrains.kotlinx.dataframe.AnyRow
@@ -13,50 +16,97 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine
1316import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1417import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1518import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
16- import org.jetbrains.kotlinx.dataframe.columns.values
17- import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1819import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators
19- import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast
2020import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
2121import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2222import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
23- import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
23+ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
24+ import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
2425import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
25- import org.jetbrains.kotlinx.dataframe.impl.zero
26- import org.jetbrains.kotlinx.dataframe.math.sum
27- import org.jetbrains.kotlinx.dataframe.math.sumOf
26+ import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
27+ import kotlin.experimental.ExperimentalTypeInference
28+ import kotlin.reflect.KClass
2829import kotlin.reflect.KProperty
29- import kotlin.reflect.full.isSubtypeOf
30+ import kotlin.reflect.KType
3031import kotlin.reflect.typeOf
3132
33+ /* TODO KDocs
34+ * Calculating the sum is supported for all primitive number types.
35+ * Nulls are filtered out.
36+ * The return type is always the same as the input type (never null), except for `Byte` and `Short`,
37+ * which are converted to `Int`.
38+ * Empty input will result in 0 in the supplied number type.
39+ * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum.
40+ */
41+
3242// region DataColumn
3343
34- @JvmName(" sumT" )
35- public fun <T : Number > DataColumn<T>.sum (): T = values.sum(type())
44+ @JvmName(" sumShort" )
45+ public fun DataColumn<Short?>.sum (): Int = Aggregators .sum.aggregate(this ) as Int
46+
47+ @JvmName(" sumByte" )
48+ public fun DataColumn<Byte?>.sum (): Int = Aggregators .sum.aggregate(this ) as Int
3649
37- @JvmName(" sumTNullable" )
38- public fun <T : Number > DataColumn<T?>.sum (): T = values.sum(type())
50+ @Suppress(" UNCHECKED_CAST" )
51+ @JvmName(" sumNumber" )
52+ public fun <T : Number > DataColumn<T?>.sum (): T = Aggregators .sum.aggregate(this ) as T
3953
40- public inline fun <T , reified R : Number > DataColumn<T>.sumOf (noinline expression : (T ) -> R ): R ? =
41- (Aggregators .sum as Aggregator <* , * >).cast<R >().aggregateOf(this , expression)
54+ @JvmName(" sumOfShort" )
55+ @OverloadResolutionByLambdaReturnType
56+ public fun <C > DataColumn<C>.sumOf (expression : (C ) -> Short? ): Int =
57+ Aggregators .sum.aggregateOf(this , expression) as Int
58+
59+ @JvmName(" sumOfByte" )
60+ @OverloadResolutionByLambdaReturnType
61+ public fun <C > DataColumn<C>.sumOf (expression : (C ) -> Byte? ): Int = Aggregators .sum.aggregateOf(this , expression) as Int
62+
63+ @JvmName(" sumOfNumber" )
64+ @OverloadResolutionByLambdaReturnType
65+ public inline fun <C , reified V : Number > DataColumn<C>.sumOf (crossinline expression : (C ) -> V ? ): V =
66+ Aggregators .sum.aggregateOf(this , expression) as V
4267
4368// endregion
4469
4570// region DataRow
4671
47- public fun AnyRow.rowSum (): Number =
48- Aggregators .sum.aggregateCalculatingType(
49- values = values().filterIsInstance< Number >(),
50- valueTypes = columnTypes().filter { it.isSubtypeOf(typeOf< Number ?>()) }.toSet(),
51- ) ? : 0
72+ public fun AnyRow.rowSum (): Number = Aggregators .sum.aggregateOfRow( this , primitiveOrMixedNumberColumns())
73+
74+ @JvmName( " rowSumOfShort " )
75+ public inline fun < reified T : Short ? > AnyRow. rowSumOf ( _kClass : KClass < Short > = Short : :class): Int =
76+ rowSumOf(typeOf< T >()) as Int
5277
53- public inline fun <reified T : Number > AnyRow.rowSumOf (): T = values().filterIsInstance<T >().sum(typeOf<T >())
78+ @JvmName(" rowSumOfByte" )
79+ public inline fun <reified T : Byte ? > AnyRow.rowSumOf (_kClass : KClass <Byte > = Byte : :class): Int =
80+ rowSumOf(typeOf<T >()) as Int
5481
82+ @JvmName(" rowSumOfInt" )
83+ public inline fun <reified T : Int ? > AnyRow.rowSumOf (_kClass : KClass <Int > = Int : :class): Int =
84+ rowSumOf(typeOf<T >()) as Int
85+
86+ @JvmName(" rowSumOfLong" )
87+ public inline fun <reified T : Long ? > AnyRow.rowSumOf (_kClass : KClass <Long > = Long : :class): Long =
88+ rowSumOf(typeOf<T >()) as Long
89+
90+ @JvmName(" rowSumOfFloat" )
91+ public inline fun <reified T : Float ? > AnyRow.rowSumOf (_kClass : KClass <Float > = Float : :class): Float =
92+ rowSumOf(typeOf<T >()) as Float
93+
94+ @JvmName(" rowSumOfDouble" )
95+ public inline fun <reified T : Double ? > AnyRow.rowSumOf (_kClass : KClass <Double > = Double : :class): Double =
96+ rowSumOf(typeOf<T >()) as Double
97+
98+ // unfortunately, we cannot make a `reified T : Number?` due to clashes
99+ public fun AnyRow.rowSumOf (type : KType ): Number {
100+ require(type.isPrimitiveOrMixedNumber()) {
101+ " Type $type is not a primitive number type. Mean only supports primitive number types."
102+ }
103+ return Aggregators .sum.aggregateOfRow(this ) { colsOf(type) }
104+ }
55105// endregion
56106
57107// region DataFrame
58108
59- public fun <T > DataFrame<T>.sum (): DataRow <T > = sumFor(numberColumns ())
109+ public fun <T > DataFrame<T>.sum (): DataRow <T > = sumFor(primitiveOrMixedNumberColumns ())
60110
61111public fun <T , C : Number > DataFrame<T>.sumFor (columns : ColumnsForAggregateSelector <T , C ?>): DataRow <T > =
62112 Aggregators .sum.aggregateFor(this , columns)
@@ -71,28 +121,70 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
71121public fun <T , C : Number > DataFrame<T>.sumFor (vararg columns : KProperty <C ?>): DataRow <T > =
72122 sumFor { columns.toColumnSet() }
73123
124+ @JvmName(" sumShort" )
125+ @OverloadResolutionByLambdaReturnType
126+ public fun <T > DataFrame<T>.sum (columns : ColumnsSelector <T , Short ?>): Int =
127+ Aggregators .sum.aggregateAll(this , columns) as Int
128+
129+ @JvmName(" sumByte" )
130+ @OverloadResolutionByLambdaReturnType
131+ public fun <T > DataFrame<T>.sum (columns : ColumnsSelector <T , Byte ?>): Int =
132+ Aggregators .sum.aggregateAll(this , columns) as Int
133+
134+ @JvmName(" sumNumber" )
135+ @OverloadResolutionByLambdaReturnType
74136public inline fun <T , reified C : Number > DataFrame<T>.sum (noinline columns : ColumnsSelector <T , C ?>): C =
75- ( Aggregators .sum.aggregateAll(this , columns) as C ? ) ? : C :: class .zero()
137+ Aggregators .sum.aggregateAll(this , columns) as C
76138
139+ @JvmName(" sumShort" )
140+ @AccessApiOverload
141+ public fun <T > DataFrame<T>.sum (vararg columns : ColumnReference <Short ?>): Int = sum { columns.toColumnSet() }
142+
143+ @JvmName(" sumByte" )
144+ @AccessApiOverload
145+ public fun <T > DataFrame<T>.sum (vararg columns : ColumnReference <Byte ?>): Int = sum { columns.toColumnSet() }
146+
147+ @JvmName(" sumNumber" )
77148@AccessApiOverload
78149public inline fun <T , reified C : Number > DataFrame<T>.sum (vararg columns : ColumnReference <C ?>): C =
79150 sum { columns.toColumnSet() }
80151
81- public fun <T > DataFrame<T>.sum (vararg columns : String ): Number = sum { columns.toColumnsSetOf() }
152+ public fun <T > DataFrame<T>.sum (vararg columns : String ): Number = sum { columns.toColumnsSetOf<Number ?>() }
153+
154+ @JvmName(" sumShort" )
155+ @AccessApiOverload
156+ public fun <T > DataFrame<T>.sum (vararg columns : KProperty <Short ?>): Int = sum { columns.toColumnSet() }
82157
158+ @JvmName(" sumByte" )
159+ @AccessApiOverload
160+ public fun <T > DataFrame<T>.sum (vararg columns : KProperty <Byte ?>): Int = sum { columns.toColumnSet() }
161+
162+ @JvmName(" sumNumber" )
83163@AccessApiOverload
84164public inline fun <T , reified C : Number > DataFrame<T>.sum (vararg columns : KProperty <C ?>): C =
85165 sum { columns.toColumnSet() }
86166
87- public inline fun <T , reified C : Number ?> DataFrame<T>.sumOf (crossinline expression : RowExpression <T , C >): C =
88- rows().sumOf(typeOf<C >()) { expression(it, it) }
167+ @JvmName(" sumOfShort" )
168+ @OverloadResolutionByLambdaReturnType
169+ public fun <T > DataFrame<T>.sumOf (expression : RowExpression <T , Short ?>): Int =
170+ Aggregators .sum.aggregateOf(this , expression) as Int
171+
172+ @JvmName(" sumOfByte" )
173+ @OverloadResolutionByLambdaReturnType
174+ public fun <T > DataFrame<T>.sumOf (expression : RowExpression <T , Byte ?>): Int =
175+ Aggregators .sum.aggregateOf(this , expression) as Int
176+
177+ @JvmName(" sumOfNumber" )
178+ @OverloadResolutionByLambdaReturnType
179+ public inline fun <T , reified C : Number > DataFrame<T>.sumOf (crossinline expression : RowExpression <T , C ?>): C =
180+ Aggregators .sum.aggregateOf(this , expression) as C
89181
90182// endregion
91183
92184// region GroupBy
93185@Refine
94186@Interpretable(" GroupBySum1" )
95- public fun <T > Grouped<T>.sum (): DataFrame <T > = sumFor(numberColumns ())
187+ public fun <T > Grouped<T>.sum (): DataFrame <T > = sumFor(primitiveOrMixedNumberColumns ())
96188
97189@Refine
98190@Interpretable(" GroupBySum0" )
@@ -136,7 +228,7 @@ public inline fun <T, reified R : Number> Grouped<T>.sumOf(
136228
137229// region Pivot
138230
139- public fun <T > Pivot<T>.sum (separate : Boolean = false): DataRow <T > = sumFor(separate, numberColumns ())
231+ public fun <T > Pivot<T>.sum (separate : Boolean = false): DataRow <T > = sumFor(separate, primitiveOrMixedNumberColumns ())
140232
141233public fun <T , R : Number > Pivot<T>.sumFor (
142234 separate : Boolean = false,
@@ -166,14 +258,15 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
166258@AccessApiOverload
167259public fun <T , C : Number > Pivot<T>.sum (vararg columns : KProperty <C ?>): DataRow <T > = sum { columns.toColumnSet() }
168260
169- public inline fun <T , reified R : Number > Pivot<T>.sumOf (crossinline expression : RowExpression <T , R >): DataRow <T > =
261+ public inline fun <T , reified R : Number > Pivot<T>.sumOf (crossinline expression : RowExpression <T , R ? >): DataRow <T > =
170262 delegate { sumOf(expression) }
171263
172264// endregion
173265
174266// region PivotGroupBy
175267
176- public fun <T > PivotGroupBy<T>.sum (separate : Boolean = false): DataFrame <T > = sumFor(separate, numberColumns())
268+ public fun <T > PivotGroupBy<T>.sum (separate : Boolean = false): DataFrame <T > =
269+ sumFor(separate, primitiveOrMixedNumberColumns())
177270
178271public fun <T , R : Number > PivotGroupBy<T>.sumFor (
179272 separate : Boolean = false,
@@ -209,7 +302,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
209302 sum { columns.toColumnSet() }
210303
211304public inline fun <T , reified R : Number > PivotGroupBy<T>.sumOf (
212- crossinline expression : RowExpression <T , R >,
305+ crossinline expression : RowExpression <T , R ? >,
213306): DataFrame <T > = Aggregators .sum.aggregateOf(this , expression)
214307
215308// endregion
0 commit comments