Skip to content

Commit 1646c79

Browse files
committed
expanded std tests
1 parent c9f090b commit 1646c79

File tree

1 file changed

+237
-0
lines changed
  • core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics

1 file changed

+237
-0
lines changed

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/std.kt

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@ package org.jetbrains.kotlinx.dataframe.statistics
22

33
import io.kotest.assertions.throwables.shouldThrow
44
import io.kotest.matchers.doubles.shouldBeNaN
5+
import io.kotest.matchers.floats.shouldBeNaN
56
import io.kotest.matchers.shouldBe
67
import org.jetbrains.kotlinx.dataframe.DataColumn
78
import org.jetbrains.kotlinx.dataframe.api.asSequence
89
import org.jetbrains.kotlinx.dataframe.api.columnOf
910
import org.jetbrains.kotlinx.dataframe.api.columnTypes
1011
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
1112
import org.jetbrains.kotlinx.dataframe.api.ddof_default
13+
import org.jetbrains.kotlinx.dataframe.api.rowStd
1214
import org.jetbrains.kotlinx.dataframe.api.std
15+
import org.jetbrains.kotlinx.dataframe.api.stdFor
16+
import org.jetbrains.kotlinx.dataframe.api.stdOf
1317
import org.jetbrains.kotlinx.dataframe.impl.nothingType
1418
import org.jetbrains.kotlinx.dataframe.math.std
1519
import org.jetbrains.kotlinx.dataframe.type
@@ -59,6 +63,239 @@ class StdTests {
5963
df.std { value } shouldBe expected
6064
}
6165

66+
@Test
fun `std with different numeric types`() {
    // sqrt(2.5): the standard deviation of 1..5 computed with an n-1 denominator.
    // The same result is expected regardless of the element type of the column.
    val expected = 1.5811388300841898

    // Integral element types: Int, Long, Short and Byte
    columnOf(1, 2, 3, 4, 5).std() shouldBe expected
    columnOf(1L, 2L, 3L, 4L, 5L).std() shouldBe expected
    columnOf(1.toShort(), 2.toShort(), 3.toShort(), 4.toShort(), 5.toShort()).std() shouldBe expected
    columnOf(1.toByte(), 2.toByte(), 3.toByte(), 4.toByte(), 5.toByte()).std() shouldBe expected

    // Floating point element types: Double and Float
    columnOf(1.0, 2.0, 3.0, 4.0, 5.0).std() shouldBe expected
    columnOf(1.0f, 2.0f, 3.0f, 4.0f, 5.0f).std() shouldBe expected
}
78+
79+
@Test
fun `std with null`() {
    // The expected value equals the std of the four non-null entries [1, 2, 4, 5].
    columnOf(1, 2, null, 4, 5).std() shouldBe 1.8257418583505538
}
84+
85+
@Test
fun `std with mixed numeric type`() {
    // A single Number column whose values are Int, Long, Float, Double and Short.
    val mixed = columnOf<Number>(1, 2L, 3.0f, 4.0, 5.toShort())

    // Same expected value as for the homogeneous column 1..5.
    mixed.std() shouldBe 1.5811388300841898
}
89+
90+
@Test
fun `std with just NaNs`() {
    // A fresh column is built for every assertion.
    fun onlyNaNs() = columnOf(Double.NaN, Double.NaN)

    // Default settings: the result is NaN
    onlyNaNs().std().shouldBeNaN()

    // skipNaN = true on a column holding nothing but NaN: the result is still expected to be NaN
    onlyNaNs().std(skipNaN = true).shouldBeNaN()
}
97+
98+
@Test
fun `std with nans and nulls`() {
    // A fresh column containing both a NaN and a null is built for every assertion.
    fun withNaNAndNull() = columnOf(1.0, 2.0, Double.NaN, 4.0, null)

    // Default settings: a single NaN makes the whole result NaN
    withNaNAndNull().std().shouldBeNaN()

    // skipNaN = true: the expected value is the std of the remaining values [1.0, 2.0, 4.0]
    withNaNAndNull().std(skipNaN = true) shouldBe 1.5275252316519465
}
106+
107+
@Test
fun `stdOf with transformer function`() {
    // String column whose entries all parse as integers
    val numerals = columnOf("1", "2", "3", "4", "5")

    // Each entry is converted to Int before the std is taken
    val result = numerals.stdOf { text -> text.toInt() }

    result shouldBe 1.5811388300841898
}
113+
114+
@Test
fun `stdOf with transformer function with nulls`() {
    val numerals = columnOf("1", "2", null, "4", "5")

    // The null entry maps to a null Int; the expected value is the std of [1, 2, 4, 5]
    val result = numerals.stdOf { text -> text?.toInt() }

    result shouldBe 1.8257418583505538
}
119+
120+
@Test
fun `stdOf with transformer function with NaNs`() {
    // Maps a string to its Double value; unparsable input and a parsed NaN both yield NaN.
    fun parseOrNaN(text: String): Double = text.toDoubleOrNull() ?: Double.NaN

    val mixedValues = columnOf("1.0", "2.0", "NaN", "4.0", "5.0")

    // Default settings: the NaN produced by the transformer makes the result NaN
    mixedValues.stdOf { parseOrNaN(it) }.shouldBeNaN()

    // skipNaN = true: the expected value is the std of [1.0, 2.0, 4.0, 5.0]
    mixedValues.stdOf(skipNaN = true) { parseOrNaN(it) } shouldBe 1.8257418583505538
}
135+
136+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe`() {
    val df = dataFrameOf("a", "b", "c")(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Every row holds three consecutive integers, so the same std is expected for each of them
    for (rowIndex in 0..2) {
        df[rowIndex].rowStd() shouldBe 1.0
    }
}
151+
152+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe and nulls`() {
    val df = dataFrameOf("a", "b", "c")(
        1, 2, 3,
        4, null, 6,
        7, 8, 9,
    )

    // Expected std per row, in row order; the middle row is the std of [4, 6] only
    val expectedPerRow = listOf(1.0, 1.4142135623730951, 1.0)

    expectedPerRow.forEachIndexed { rowIndex, expected ->
        df[rowIndex].rowStd() shouldBe expected
    }
}
167+
168+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `rowStd with dataframe and NaNs`() {
    // Every row contains exactly one NaN, each in a different column
    val dfWithNaN = dataFrameOf("a", "b", "c")(
        1.0, Double.NaN, 3.0,
        Double.NaN, 5.0, 6.0,
        7.0, 8.0, Double.NaN,
    )

    // Default settings: each row yields NaN
    for (rowIndex in 0..2) {
        dfWithNaN[rowIndex].rowStd().shouldBeNaN()
    }

    // skipNaN = true: expected values are the std of the two remaining numbers in each row
    val expectedPerRow = listOf(
        1.4142135623730951, // std of [1.0, 3.0]
        0.7071067811865476, // std of [5.0, 6.0]
        0.7071067811865476, // std of [7.0, 8.0]
    )
    expectedPerRow.forEachIndexed { rowIndex, expected ->
        dfWithNaN[rowIndex].rowStd(skipNaN = true) shouldBe expected
    }
}
188+
189+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `dataframe std`() {
    val df = dataFrameOf("a", "b", "c")(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // std over all columns; values within each column are 3 apart, so each column is expected to give 3.0
    val allColumns = df.std()
    for (name in listOf("a", "b", "c")) {
        allColumns[name] shouldBe 3.0
    }

    // std restricted to the columns "a" and "c"
    val selectedColumns = df.stdFor("a", "c")
    for (name in listOf("a", "c")) {
        selectedColumns[name] shouldBe 3.0
    }
}
210+
211+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `dataframe stdOf`() {
    val df = dataFrameOf("a", "b", "c")(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // The row expression yields a + c for each row, i.e. [4, 10, 16]
    val stdOfSums = df.stdOf {
        val a = "a"<Int>()
        val c = "c"<Int>()
        a + c
    }

    stdOfSums shouldBe 6.0
}
224+
225+
@[Test Suppress("ktlint:standard:argument-list-wrapping")]
fun `std with skipNaN for floating point numbers`() {
    // --- Single columns ---

    // Float column with one NaN: NaN by default, std of [1, 2, 4, 5] when NaNs are skipped
    val floats = columnOf(1.0f, 2.0f, Float.NaN, 4.0f, 5.0f)
    floats.std().shouldBeNaN()
    floats.std(skipNaN = true) shouldBe 1.8257418583505538

    // Double column with one NaN: same expectations as the Float column
    val doubles = columnOf(1.0, 2.0, Double.NaN, 4.0, 5.0)
    doubles.std().shouldBeNaN()
    doubles.std(skipNaN = true) shouldBe 1.8257418583505538

    // NaNs at the start, middle and end: only [2.0, 4.0] remain when NaNs are skipped
    val mostlyNaN = columnOf(Float.NaN, 2.0f, Float.NaN, 4.0f, Float.NaN)
    mostlyNaN.std().shouldBeNaN()
    mostlyNaN.std(skipNaN = true) shouldBe 1.4142135623730951

    // Nothing but NaN: the result is expected to be NaN with and without skipNaN
    val nothingButNaN = columnOf(Float.NaN, Float.NaN, Float.NaN)
    nothingButNaN.std().shouldBeNaN()
    nothingButNaN.std(skipNaN = true).shouldBeNaN()

    // --- DataFrame with one NaN per column (column "c" holds a Float.NaN among Doubles) ---
    val dfWithNaN = dataFrameOf("a", "b", "c")(
        1.0, Double.NaN, 3.0,
        4.0, 5.0, Float.NaN,
        Double.NaN, 8.0, 9.0,
    )

    // std over all columns, default settings: every column contains a NaN
    val defaultStds = dfWithNaN.std()
    for (name in listOf("a", "b", "c")) {
        (defaultStds[name] as Double).shouldBeNaN()
    }

    // std over all columns, skipNaN = true
    val skippingStds = dfWithNaN.std(skipNaN = true)
    skippingStds["a"] shouldBe 2.1213203435596424 // std of [1.0, 4.0]
    skippingStds["b"] shouldBe 2.1213203435596424 // std of [5.0, 8.0]
    skippingStds["c"] shouldBe 4.242640687119285 // std of [3.0, 9.0]

    // stdFor on selected columns, default settings
    val defaultSelected = dfWithNaN.stdFor("a", "c")
    for (name in listOf("a", "c")) {
        (defaultSelected[name] as Double).shouldBeNaN()
    }

    // stdFor on selected columns, skipNaN = true
    val skippingSelected = dfWithNaN.stdFor("a", "c", skipNaN = true)
    skippingSelected["a"] shouldBe 2.1213203435596424 // std of [1.0, 4.0]
    skippingSelected["c"] shouldBe 4.242640687119285 // std of [3.0, 9.0]

    // --- stdOf with a row expression that yields NaN for some rows ---
    val dfForTransform = dataFrameOf("a", "b")(
        1.0, 0.0,
        4.0, 2.0,
        0.0, 0.0,
    )

    // The expression returns NaN explicitly whenever b is 0.0, otherwise a / b.
    // Default settings: the NaN rows make the result NaN.
    dfForTransform.stdOf {
        val divisor = "b"<Double>()
        if (divisor != 0.0) "a"<Double>() / divisor else Double.NaN
    }.shouldBeNaN()

    // skipNaN = true: only 4.0 / 2.0 = 2.0 is left, and the std of a single value is NaN
    dfForTransform.stdOf(skipNaN = true) {
        val divisor = "b"<Double>()
        if (divisor != 0.0) "a"<Double>() / divisor else Double.NaN
    }.shouldBeNaN()
}
298+
62299
@Test
63300
fun `std on empty or nullable column`() {
64301
val empty = DataColumn.createValueColumn("", emptyList<Nothing>(), nothingType(false))

0 commit comments

Comments
 (0)