@@ -2,14 +2,18 @@ package org.jetbrains.kotlinx.dataframe.statistics
2
2
3
3
import io.kotest.assertions.throwables.shouldThrow
4
4
import io.kotest.matchers.doubles.shouldBeNaN
5
+ import io.kotest.matchers.floats.shouldBeNaN
5
6
import io.kotest.matchers.shouldBe
6
7
import org.jetbrains.kotlinx.dataframe.DataColumn
7
8
import org.jetbrains.kotlinx.dataframe.api.asSequence
8
9
import org.jetbrains.kotlinx.dataframe.api.columnOf
9
10
import org.jetbrains.kotlinx.dataframe.api.columnTypes
10
11
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
11
12
import org.jetbrains.kotlinx.dataframe.api.ddof_default
13
+ import org.jetbrains.kotlinx.dataframe.api.rowStd
12
14
import org.jetbrains.kotlinx.dataframe.api.std
15
+ import org.jetbrains.kotlinx.dataframe.api.stdFor
16
+ import org.jetbrains.kotlinx.dataframe.api.stdOf
13
17
import org.jetbrains.kotlinx.dataframe.impl.nothingType
14
18
import org.jetbrains.kotlinx.dataframe.math.std
15
19
import org.jetbrains.kotlinx.dataframe.type
@@ -59,6 +63,239 @@ class StdTests {
59
63
df.std { value } shouldBe expected
60
64
}
61
65
66
@Test
fun `std with different numeric types`() {
    // sqrt(10 / 4): the standard deviation of 1..5 with n - 1 in the denominator.
    // Every element type below is expected to produce exactly this value.
    val expected = 1.5811388300841898

    // Integral element types: Int, Long, Short, Byte
    val ints = columnOf(1, 2, 3, 4, 5)
    val longs = columnOf(1L, 2L, 3L, 4L, 5L)
    val shorts = columnOf(1.toShort(), 2.toShort(), 3.toShort(), 4.toShort(), 5.toShort())
    val bytes = columnOf(1.toByte(), 2.toByte(), 3.toByte(), 4.toByte(), 5.toByte())
    ints.std() shouldBe expected
    longs.std() shouldBe expected
    shorts.std() shouldBe expected
    bytes.std() shouldBe expected

    // Floating-point element types: Double, Float
    val doubles = columnOf(1.0, 2.0, 3.0, 4.0, 5.0)
    val floats = columnOf(1.0f, 2.0f, 3.0f, 4.0f, 5.0f)
    doubles.std() shouldBe expected
    floats.std() shouldBe expected
}
78
+
79
@Test
fun `std with null`() {
    // sqrt(10 / 3): the n - 1 standard deviation of the non-null entries [1, 2, 4, 5].
    val expected = 1.8257418583505538

    columnOf(1, 2, null, 4, 5).std() shouldBe expected
}
84
+
85
@Test
fun `std with mixed numeric type`() {
    // One column typed as Number that holds Int, Long, Float, Double and Short values 1..5.
    val mixed = columnOf<Number>(1, 2L, 3.0f, 4.0, 5.toShort())

    mixed.std() shouldBe 1.5811388300841898
}
89
+
90
@Test
fun `std with just NaNs`() {
    val onlyNaNs = columnOf(Double.NaN, Double.NaN)

    // Default settings: the result is NaN
    onlyNaNs.std().shouldBeNaN()

    // skipNaN = true leaves nothing to aggregate, and the result is still expected to be NaN
    onlyNaNs.std(skipNaN = true).shouldBeNaN()
}
97
+
98
@Test
fun `std with nans and nulls`() {
    val values = columnOf(1.0, 2.0, Double.NaN, 4.0, null)

    // A single NaN is expected to turn the whole result into NaN by default
    values.std().shouldBeNaN()

    // With skipNaN = true the expected value is the n - 1 standard deviation of [1.0, 2.0, 4.0]
    val stdWithoutNaN = values.std(skipNaN = true)
    stdWithoutNaN shouldBe 1.5275252316519465
}
106
+
107
@Test
fun `stdOf with transformer function`() {
    // The strings are converted with String.toInt(), which does not trim whitespace:
    // a literal such as " 1" would throw NumberFormatException, so the digits must stand alone.
    val strings = columnOf("1", "2", "3", "4", "5")

    // Expected value is the n - 1 standard deviation of the parsed integers 1..5.
    strings.stdOf { it.toInt() } shouldBe 1.5811388300841898
}
113
+
114
@Test
fun `stdOf with transformer function with nulls`() {
    // The non-null strings are converted with String.toInt(), which does not trim whitespace:
    // a literal such as " 1" would throw NumberFormatException, so the digits must stand alone.
    val stringsWithNull = columnOf("1", "2", null, "4", "5")

    // The null entry stays null through the safe call; the expected value is the
    // n - 1 standard deviation of the remaining integers [1, 2, 4, 5].
    stringsWithNull.stdOf { it?.toInt() } shouldBe 1.8257418583505538
}
119
+
120
@Test
fun `stdOf with transformer function with NaNs`() {
    // Shared transformer: unparsable text and a parsed NaN both become Double.NaN,
    // any other parsed value is passed through unchanged.
    fun parseOrNaN(text: String): Double {
        val parsed = text.toDoubleOrNull()
        return if (parsed != null && !parsed.isNaN()) parsed else Double.NaN
    }

    val mixedValues = columnOf(" 1.0", " 2.0", " NaN", " 4.0", " 5.0")

    // Default settings: the NaN entry is expected to make the result NaN
    mixedValues.stdOf { parseOrNaN(it) }.shouldBeNaN()

    // skipNaN = true: the NaN entry is expected to be left out of the computation
    mixedValues.stdOf(skipNaN = true) { parseOrNaN(it) } shouldBe 1.8257418583505538
}
135
+
136
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe`() {
    val df = dataFrameOf(
        " a", " b", " c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Every row holds three consecutive integers, so each row-wise std is expected to be 1.0
    val expected = 1.0
    for (rowIndex in 0..2) {
        df[rowIndex].rowStd() shouldBe expected
    }
}
151
+
152
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe and nulls`() {
    val df = dataFrameOf(
        " a", " b", " c",
    )(
        1, 2, 3,
        4, null, 6,
        7, 8, 9,
    )

    val stdOfCompleteRow = 1.0
    val stdOfFourAndSix = 1.4142135623730951 // std of [4, 6]

    // Rows 0 and 2 are complete; row 1 is expected to be computed from its two non-null cells
    df[0].rowStd() shouldBe stdOfCompleteRow
    df[1].rowStd() shouldBe stdOfFourAndSix
    df[2].rowStd() shouldBe stdOfCompleteRow
}
167
+
168
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe and NaNs`() {
    // Each row contains exactly one NaN, in a different column
    val dfWithNaN = dataFrameOf(
        " a", " b", " c",
    )(
        1.0, Double.NaN, 3.0,
        Double.NaN, 5.0, 6.0,
        7.0, 8.0, Double.NaN,
    )

    // Default settings: a NaN anywhere in the row is expected to give a NaN result
    for (rowIndex in 0..2) {
        dfWithNaN[rowIndex].rowStd().shouldBeNaN()
    }

    // skipNaN = true: only the two remaining values of each row are expected to contribute
    val stdOfOneAndThree = 1.4142135623730951 // std of [1.0, 3.0]
    val stdOfNeighbours = 0.7071067811865476 // std of [5.0, 6.0] and of [7.0, 8.0]
    dfWithNaN[0].rowStd(skipNaN = true) shouldBe stdOfOneAndThree
    dfWithNaN[1].rowStd(skipNaN = true) shouldBe stdOfNeighbours
    dfWithNaN[2].rowStd(skipNaN = true) shouldBe stdOfNeighbours
}
188
+
189
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `dataframe std`() {
    val df = dataFrameOf(
        " a", " b", " c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Each column steps by 3 ([1, 4, 7], [2, 5, 8], [3, 6, 9]), so every column std is expected to be 3.0
    val expected = 3.0

    // std() over the whole frame: one value per column, looked up by column name
    val allStds = df.std()
    allStds[" a"] shouldBe expected
    allStds[" b"] shouldBe expected
    allStds[" c"] shouldBe expected

    // stdFor() restricted to two named columns
    val selectedStds = df.stdFor(" a", " c")
    selectedStds[" a"] shouldBe expected
    selectedStds[" c"] shouldBe expected
}
210
+
211
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `dataframe stdOf`() {
    val df = dataFrameOf(
        " a", " b", " c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // The row expression yields a + c per row, i.e. [4, 10, 16];
    // the expected result is the std of those three sums.
    val stdOfSums = df.stdOf { " a"<Int>() + " c"<Int>() }
    stdOfSums shouldBe 6.0
}
224
+
225
@[Test Suppress(" ktlint:standard:argument-list-wrapping")]
fun `std with skipNaN for floating point numbers`() {
    // Expected values that occur more than once below
    val stdOf1245 = 1.8257418583505538 // std of [1, 2, 4, 5]
    val stdGapThree = 2.1213203435596424 // std of [1.0, 4.0] and of [5.0, 8.0]
    val stdGapSix = 4.242640687119285 // std of [3.0, 9.0]

    // --- Columns -----------------------------------------------------------------------------

    // Float column with a single NaN
    val floatWithNaN = columnOf(1.0f, 2.0f, Float.NaN, 4.0f, 5.0f)
    floatWithNaN.std().shouldBeNaN() // default: the result is NaN
    floatWithNaN.std(skipNaN = true) shouldBe stdOf1245 // NaN left out

    // Double column with a single NaN
    val doubleWithNaN = columnOf(1.0, 2.0, Double.NaN, 4.0, 5.0)
    doubleWithNaN.std().shouldBeNaN() // default: the result is NaN
    doubleWithNaN.std(skipNaN = true) shouldBe stdOf1245 // NaN left out

    // Float column with NaNs at the start, in the middle and at the end
    val multipleNaN = columnOf(Float.NaN, 2.0f, Float.NaN, 4.0f, Float.NaN)
    multipleNaN.std().shouldBeNaN() // default: the result is NaN
    multipleNaN.std(skipNaN = true) shouldBe 1.4142135623730951 // std of [2.0, 4.0]

    // Float column that holds nothing but NaN
    val allNaN = columnOf(Float.NaN, Float.NaN, Float.NaN)
    allNaN.std().shouldBeNaN() // default: the result is NaN
    allNaN.std(skipNaN = true).shouldBeNaN() // nothing remains after skipping, still NaN

    // --- DataFrame aggregation ---------------------------------------------------------------

    // Every column contains exactly one NaN (column c holds a Float.NaN among Double values)
    val dfWithNaN = dataFrameOf(
        " a", " b", " c",
    )(
        1.0, Double.NaN, 3.0,
        4.0, 5.0, Float.NaN,
        Double.NaN, 8.0, 9.0,
    )

    // std() with default settings: every column result is expected to be NaN
    val stdsWithNaN = dfWithNaN.std()
    (stdsWithNaN[" a"] as Double).shouldBeNaN()
    (stdsWithNaN[" b"] as Double).shouldBeNaN()
    (stdsWithNaN[" c"] as Double).shouldBeNaN()

    // std(skipNaN = true): each column is reduced to its two non-NaN values
    val stdsSkipNaN = dfWithNaN.std(skipNaN = true)
    stdsSkipNaN[" a"] shouldBe stdGapThree // std of [1.0, 4.0]
    stdsSkipNaN[" b"] shouldBe stdGapThree // std of [5.0, 8.0]
    stdsSkipNaN[" c"] shouldBe stdGapSix // std of [3.0, 9.0]

    // stdFor() on two named columns, default settings
    val stdForWithNaN = dfWithNaN.stdFor(" a", " c")
    (stdForWithNaN[" a"] as Double).shouldBeNaN()
    (stdForWithNaN[" c"] as Double).shouldBeNaN()

    // stdFor() on the same columns with skipNaN = true
    val stdForSkipNaN = dfWithNaN.stdFor(" a", " c", skipNaN = true)
    stdForSkipNaN[" a"] shouldBe stdGapThree // std of [1.0, 4.0]
    stdForSkipNaN[" c"] shouldBe stdGapSix // std of [3.0, 9.0]

    // --- stdOf with a row expression that can yield NaN ----------------------------------------

    val dfForTransform = dataFrameOf(
        " a", " b",
    )(
        1.0, 0.0,
        4.0, 2.0,
        0.0, 0.0,
    )

    // The expression maps a zero divisor to NaN explicitly (rows 0 and 2) and divides otherwise.
    // Default settings: the NaN rows are expected to make the result NaN.
    dfForTransform.stdOf {
        val divisor = " b"<Double>()
        if (divisor != 0.0) " a"<Double>() / divisor else Double.NaN
    }.shouldBeNaN()

    // skipNaN = true: only 4.0 / 2.0 = 2.0 remains, and the std of a single value is NaN
    dfForTransform.stdOf(skipNaN = true) {
        val divisor = " b"<Double>()
        if (divisor != 0.0) " a"<Double>() / divisor else Double.NaN
    }.shouldBeNaN()
}
298
+
62
299
@Test
63
300
fun `std on empty or nullable column` () {
64
301
val empty = DataColumn .createValueColumn(" " , emptyList<Nothing >(), nothingType(false ))
0 commit comments