@@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.io
3
3
import org.apache.arrow.memory.RootAllocator
4
4
import org.apache.arrow.vector.BigIntVector
5
5
import org.apache.arrow.vector.BitVector
6
+ import org.apache.arrow.vector.DateDayVector
7
+ import org.apache.arrow.vector.DateMilliVector
6
8
import org.apache.arrow.vector.Decimal256Vector
7
9
import org.apache.arrow.vector.DecimalVector
8
10
import org.apache.arrow.vector.DurationVector
@@ -28,16 +30,21 @@ import org.apache.arrow.vector.complex.StructVector
28
30
import org.apache.arrow.vector.ipc.ArrowFileReader
29
31
import org.apache.arrow.vector.ipc.ArrowStreamReader
30
32
import org.apache.arrow.vector.types.pojo.Field
33
+ import org.apache.arrow.vector.util.DateUtility
31
34
import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
32
35
import org.jetbrains.kotlinx.dataframe.AnyBaseCol
33
36
import org.jetbrains.kotlinx.dataframe.AnyFrame
34
37
import org.jetbrains.kotlinx.dataframe.DataColumn
35
38
import org.jetbrains.kotlinx.dataframe.DataFrame
36
39
import org.jetbrains.kotlinx.dataframe.api.Infer
37
- import org.jetbrains.kotlinx.dataframe.api.concat
40
+ import org.jetbrains.kotlinx.dataframe.api.cast
41
+ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
42
+ import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
43
+ import org.jetbrains.kotlinx.dataframe.api.getColumn
38
44
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
39
45
import org.jetbrains.kotlinx.dataframe.codeGen.AbstractDefaultReadMethod
40
46
import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod
47
+ import org.jetbrains.kotlinx.dataframe.impl.asList
41
48
import java.io.File
42
49
import java.io.InputStream
43
50
import java.math.BigDecimal
@@ -48,6 +55,7 @@ import java.nio.channels.ReadableByteChannel
48
55
import java.nio.channels.SeekableByteChannel
49
56
import java.nio.file.Files
50
57
import java.time.Duration
58
+ import java.time.LocalDate
51
59
import java.time.LocalDateTime
52
60
import kotlin.reflect.typeOf
53
61
@@ -75,6 +83,25 @@ internal object Allocator {
75
83
}
76
84
}
77
85
86
/**
 * Row-wise concatenation that does the same job as `Iterable<DataFrame<T>>.concat()` but without internal
 * type guessing: each resulting column is created with the type of the same-named column of the first
 * data frame (all batches should have the same schema).
 *
 * An empty receiver yields an empty data frame; a single data frame is returned as is.
 */
internal fun <T> Iterable<DataFrame<T>>.concatKeepingSchema(): DataFrame<T> {
    val batches = asList()
    if (batches.size == 0) return emptyDataFrame()
    if (batches.size == 1) return batches[0]

    // The first batch supplies both the column order and the column types.
    val head = batches.first()
    val mergedColumns = head.columnNames().map { columnName ->
        val allValues = batches.flatMap { batch -> batch.getColumn(columnName).values() }
        DataColumn.createValueColumn(columnName, allValues, head.getColumn(columnName).type())
    }
    return dataFrameOf(mergedColumns).cast()
}
104
+
78
105
/* *
79
106
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
80
107
*/
@@ -88,7 +115,7 @@ public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, alloca
88
115
add(df)
89
116
}
90
117
}
91
- return dfs.concat ()
118
+ return dfs.concatKeepingSchema ()
92
119
}
93
120
}
94
121
@@ -106,7 +133,7 @@ public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, al
106
133
add(df)
107
134
}
108
135
}
109
- return dfs.concat ()
136
+ return dfs.concatKeepingSchema ()
110
137
}
111
138
}
112
139
@@ -129,6 +156,11 @@ private fun Float4Vector.values(range: IntRange): List<Float?> = range.map { get
129
156
/** Collects the result of `getObject` for every index in [range], in order. */
private fun Float8Vector.values(range: IntRange): List<Double?> =
    range.map { index -> getObject(index) }
130
157
131
158
/** Collects the result of `getObject` for every index in [range], in order. */
private fun DurationVector.values(range: IntRange): List<Duration?> =
    range.map { index -> getObject(index) }
159
/**
 * Reads the values at the indices in [range] as dates.
 *
 * Each non-null value returned by `getObject` is treated as a day count: it is multiplied by
 * [DateUtility.daysToStandardMillis], converted with [DateUtility.getLocalDateTimeFromEpochMilli],
 * and reduced to its date part.
 *
 * A `null` returned by `getObject` is kept as `null` in the result instead of failing on `toLong()`.
 */
private fun DateDayVector.values(range: IntRange): List<LocalDate?> = range.map { index ->
    // Safe call: the element type is nullable, so a missing value must not be dereferenced.
    getObject(index)?.let { days ->
        DateUtility.getLocalDateTimeFromEpochMilli(days.toLong() * DateUtility.daysToStandardMillis).toLocalDate()
    }
}
162
/** Collects the result of `getObject` for every index in [range], in order. */
private fun DateMilliVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index -> getObject(index) }
163
+
132
164
/** Collects the result of `getObject` for every index in [range], in order. */
private fun TimeNanoVector.values(range: IntRange): List<Long?> =
    range.map { index -> getObject(index) }
133
165
/** Collects the result of `getObject` for every index in [range], in order. */
private fun TimeMicroVector.values(range: IntRange): List<Long?> =
    range.map { index -> getObject(index) }
134
166
/** Collects the result of `getObject` for every index in [range], in order. */
private fun TimeMilliVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index -> getObject(index) }
@@ -190,6 +222,8 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
190
222
is Float8Vector -> vector.values(range).withType()
191
223
is Float4Vector -> vector.values(range).withType()
192
224
is DurationVector -> vector.values(range).withType()
225
+ is DateDayVector -> vector.values(range).withType()
226
+ is DateMilliVector -> vector.values(range).withType()
193
227
is TimeNanoVector -> vector.values(range).withType()
194
228
is TimeMicroVector -> vector.values(range).withType()
195
229
is TimeMilliVector -> vector.values(range).withType()
@@ -199,7 +233,7 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
199
233
TODO (" not fully implemented" )
200
234
}
201
235
}
202
- return DataColumn .createValueColumn(field.name, list, type, Infer .Nulls )
236
+ return DataColumn .createValueColumn(field.name, list, type, Infer .None )
203
237
}
204
238
205
239
// IPC reading block
0 commit comments