Skip to content

Commit fa5b956

Browse files
committed
New test draft, keep nullables in schema
1 parent 4fc93d7 commit fa5b956

File tree

6 files changed

+100
-8
lines changed

6 files changed

+100
-8
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.io
33
import org.apache.arrow.memory.RootAllocator
44
import org.apache.arrow.vector.BigIntVector
55
import org.apache.arrow.vector.BitVector
6+
import org.apache.arrow.vector.DateDayVector
7+
import org.apache.arrow.vector.DateMilliVector
68
import org.apache.arrow.vector.Decimal256Vector
79
import org.apache.arrow.vector.DecimalVector
810
import org.apache.arrow.vector.DurationVector
@@ -28,16 +30,21 @@ import org.apache.arrow.vector.complex.StructVector
2830
import org.apache.arrow.vector.ipc.ArrowFileReader
2931
import org.apache.arrow.vector.ipc.ArrowStreamReader
3032
import org.apache.arrow.vector.types.pojo.Field
33+
import org.apache.arrow.vector.util.DateUtility
3134
import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
3235
import org.jetbrains.kotlinx.dataframe.AnyBaseCol
3336
import org.jetbrains.kotlinx.dataframe.AnyFrame
3437
import org.jetbrains.kotlinx.dataframe.DataColumn
3538
import org.jetbrains.kotlinx.dataframe.DataFrame
3639
import org.jetbrains.kotlinx.dataframe.api.Infer
37-
import org.jetbrains.kotlinx.dataframe.api.concat
40+
import org.jetbrains.kotlinx.dataframe.api.cast
41+
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
42+
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
43+
import org.jetbrains.kotlinx.dataframe.api.getColumn
3844
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
3945
import org.jetbrains.kotlinx.dataframe.codeGen.AbstractDefaultReadMethod
4046
import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod
47+
import org.jetbrains.kotlinx.dataframe.impl.asList
4148
import java.io.File
4249
import java.io.InputStream
4350
import java.math.BigDecimal
@@ -48,6 +55,7 @@ import java.nio.channels.ReadableByteChannel
4855
import java.nio.channels.SeekableByteChannel
4956
import java.nio.file.Files
5057
import java.time.Duration
58+
import java.time.LocalDate
5159
import java.time.LocalDateTime
5260
import kotlin.reflect.typeOf
5361

@@ -75,6 +83,25 @@ internal object Allocator {
7583
}
7684
}
7785

86+
/**
 * Concatenates the given frames like [org.jetbrains.kotlinx.dataframe.api.concat], but without
 * internal type guessing: the column types declared by the first batch are kept as-is.
 *
 * Precondition: all batches must share the same schema (same column names and types);
 * columns are looked up by name in every frame and merged positionally.
 *
 * @return an empty frame for an empty input, the single frame unchanged for a singleton input,
 *   otherwise a new frame whose columns are the per-name concatenation of all batches' values.
 */
internal fun <T> Iterable<DataFrame<T>>.concatKeepingSchema(): DataFrame<T> {
    val dataFrames = asList()
    when (dataFrames.size) {
        0 -> return emptyDataFrame()
        1 -> return dataFrames[0]
    }

    // Hoisted: the first batch is the schema authority for both column order and column types.
    val first = dataFrames.first()

    val columns = first.columnNames().map { name ->
        val values = dataFrames.flatMap { it.getColumn(name).values() }
        // Reuse the first batch's declared type instead of re-inferring it from the merged
        // values — this is what keeps nullability from the schema (see commit intent).
        DataColumn.createValueColumn(name, values, first.getColumn(name).type())
    }
    return dataFrameOf(columns).cast()
}
104+
78105
/**
79106
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
80107
*/
@@ -88,7 +115,7 @@ public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, alloca
88115
add(df)
89116
}
90117
}
91-
return dfs.concat()
118+
return dfs.concatKeepingSchema()
92119
}
93120
}
94121

@@ -106,7 +133,7 @@ public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, al
106133
add(df)
107134
}
108135
}
109-
return dfs.concat()
136+
return dfs.concatKeepingSchema()
110137
}
111138
}
112139

@@ -129,6 +156,11 @@ private fun Float4Vector.values(range: IntRange): List<Float?> = range.map { get
129156
private fun Float8Vector.values(range: IntRange): List<Double?> = range.map { getObject(it) }
130157

131158
private fun DurationVector.values(range: IntRange): List<Duration?> = range.map { getObject(it) }
159+
/**
 * Reads a date32 vector (days since the Unix epoch) into nullable [LocalDate]s.
 * Null (unset) slots are preserved as `null`: [DateDayVector.getObject] returns null there,
 * and the original unconditional `.toLong()` would have thrown an NPE.
 */
private fun DateDayVector.values(range: IntRange): List<LocalDate?> = range.map { i ->
    getObject(i)?.let { epochDays ->
        DateUtility.getLocalDateTimeFromEpochMilli(epochDays.toLong() * DateUtility.daysToStandardMillis).toLocalDate()
    }
}

/** Reads a date64 vector (milliseconds since the Unix epoch) into nullable [LocalDateTime]s. */
private fun DateMilliVector.values(range: IntRange): List<LocalDateTime?> = range.map { getObject(it) }
163+
132164
private fun TimeNanoVector.values(range: IntRange): List<Long?> = range.map { getObject(it) }
133165
private fun TimeMicroVector.values(range: IntRange): List<Long?> = range.map { getObject(it) }
134166
private fun TimeMilliVector.values(range: IntRange): List<LocalDateTime?> = range.map { getObject(it) }
@@ -190,6 +222,8 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
190222
is Float8Vector -> vector.values(range).withType()
191223
is Float4Vector -> vector.values(range).withType()
192224
is DurationVector -> vector.values(range).withType()
225+
is DateDayVector -> vector.values(range).withType()
226+
is DateMilliVector -> vector.values(range).withType()
193227
is TimeNanoVector -> vector.values(range).withType()
194228
is TimeMicroVector -> vector.values(range).withType()
195229
is TimeMilliVector -> vector.values(range).withType()
@@ -199,7 +233,7 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
199233
TODO("not fully implemented")
200234
}
201235
}
202-
return DataColumn.createValueColumn(field.name, list, type, Infer.Nulls)
236+
return DataColumn.createValueColumn(field.name, list, type, Infer.None)
203237
}
204238

205239
// IPC reading block
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,44 @@
11
import io.kotest.matchers.shouldBe
22
import org.apache.arrow.vector.util.Text
3+
import org.jetbrains.kotlinx.dataframe.DataColumn
34
import org.jetbrains.kotlinx.dataframe.DataFrame
45
import org.jetbrains.kotlinx.dataframe.api.columnOf
56
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
67
import org.jetbrains.kotlinx.dataframe.api.toColumn
78
import org.jetbrains.kotlinx.dataframe.io.readArrowFeather
9+
import org.jetbrains.kotlinx.dataframe.io.readArrowIPC
810
import org.junit.Test
911
import java.net.URL
12+
import kotlin.reflect.typeOf
1013

1114
internal class ArrowKtTest {

    /** Resolves a test resource on the classpath; fails fast (NPE) if it is missing. */
    fun testResource(resourcePath: String): URL = ArrowKtTest::class.java.classLoader.getResource(resourcePath)!!

    /** Locates a Feather-format example file. */
    fun testArrowFeather(name: String) = testResource("$name.feather")

    /** Locates an IPC-streaming-format example file. */
    fun testArrowIPC(name: String) = testResource("$name.ipc")

    @Test
    fun testReadingFromFile() {
        val feather = testArrowFeather("data-arrow_2.0.0_uncompressed")
        val df = DataFrame.readArrowFeather(feather)
        // Expected columns carry explicitly nullable element types because the reader now keeps
        // the schema's nullability instead of inferring it from the values. Explicit type
        // arguments replace the earlier unsafe `as` casts; the delegated property names
        // (a, b, c, d) define the column names and must not change.
        val a by listOf<String?>("one").toColumn()
        val b by listOf<Double?>(2.0).toColumn()
        val c by listOf<Map<String, Any?>?>(
            mapOf(
                "c1" to Text("inner"),
                "c2" to 4.0,
                "c3" to 50.0
            )
        ).toColumn()
        val d by listOf<String?>("four").toColumn()
        val expected = dataFrameOf(a, b, c, d)
        df shouldBe expected
    }

    @Test
    fun testReadingAllTypesAsEstimated() {
        // The same example data exists in both formats; both readers must agree.
        assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow")))
        assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow")))
    }
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import io.kotest.matchers.shouldBe
2+
import org.jetbrains.kotlinx.dataframe.AnyFrame
3+
import org.jetbrains.kotlinx.dataframe.DataColumn
4+
import org.jetbrains.kotlinx.dataframe.api.forEach
5+
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
6+
import kotlin.reflect.typeOf
7+
8+
/**
 * Assert that we have got the same data that was originally saved on example creation.
 *
 * The example frame is read as two concatenated batches; row `i` of the frame maps back to
 * row `iBatch(i)` of its source batch.
 *
 * NOTE(review): only the "asciiString" column is asserted so far — the remaining vals are
 * draft placeholders for per-type assertions still to be written (commit: "New test draft").
 */
internal fun assertEstimations(exampleFrame: AnyFrame) {
    /**
     * Translates a frame row index to the row index inside its source batch.
     * Assumes the first batch holds exactly 100 rows — TODO confirm against the example file.
     */
    fun iBatch(iFrame: Int): Int {
        val firstBatchSize = 100
        return if (iFrame < firstBatchSize) iFrame else iFrame - firstBatchSize
    }

    // Cast is unchecked by necessity: the element type is only known at runtime via type().
    @Suppress("UNCHECKED_CAST")
    val asciiStringCol = exampleFrame["asciiString"] as DataColumn<String?>
    asciiStringCol.type() shouldBe typeOf<String?>()
    asciiStringCol.forEachIndexed { i, element ->
        element shouldBe "Test Example ${iBatch(i)}"
    }

    // Draft placeholders below: columns fetched but not yet asserted.
    val utf8StringCol = exampleFrame["utf8String"]
    val largeStringCol = exampleFrame["largeString"]

    val booleanCol = exampleFrame["boolean"]

    val byteCol = exampleFrame["byte"]
    val shortCol = exampleFrame["short"]
    val intCol = exampleFrame["int"]
    val longIntCol = exampleFrame["longInt"]

    val unsignedByteCol = exampleFrame["unsigned_byte"]
    val unsignedShortCol = exampleFrame["unsigned_short"]
    val unsignedIntCol = exampleFrame["unsigned_int"]
    val unsignedLongIntCol = exampleFrame["unsigned_longInt"]

    val dateCol = exampleFrame["date32"]
    val datetimeCol = exampleFrame["date64"]

    val timeSecCol = exampleFrame["time32_seconds"]
    val timeMilliCol = exampleFrame["time32_milli"]

    val timeMicroCol = exampleFrame["time64_micro"]
    val timeNanoCol = exampleFrame["time64_nano"]
}
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)