Skip to content

Commit 562073b

Browse files
KopilovKopilov
authored andcommitted
Explicit ArrowFormat parameter, public Channel reading
1 parent 0d3f633 commit 562073b

File tree

2 files changed

+92
-12
lines changed

2 files changed

+92
-12
lines changed

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ kotlinDatetime = "0.3.1"
1717
junit = "4.13.2"
1818
kotestAsserions = "4.6.3"
1919
jsoup = "1.14.3"
20-
arrow = "7.0.0"
20+
arrow = "8.0.0"
2121

2222
[libraries]
2323
ksp-gradle = { group = "com.google.devtools.ksp", name = "symbol-processing-gradle-plugin", version.ref = "ksp" }

src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dataframe.io
22

33
import org.apache.arrow.memory.RootAllocator
44
import org.apache.arrow.vector.BigIntVector
5+
import org.apache.arrow.vector.BitVector
56
import org.apache.arrow.vector.Decimal256Vector
67
import org.apache.arrow.vector.DecimalVector
78
import org.apache.arrow.vector.DurationVector
@@ -27,6 +28,7 @@ import org.apache.arrow.vector.complex.StructVector
2728
import org.apache.arrow.vector.ipc.ArrowFileReader
2829
import org.apache.arrow.vector.ipc.ArrowStreamReader
2930
import org.apache.arrow.vector.types.pojo.Field
31+
import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
3032
import org.jetbrains.kotlinx.dataframe.AnyBaseColumn
3133
import org.jetbrains.kotlinx.dataframe.AnyFrame
3234
import org.jetbrains.kotlinx.dataframe.DataColumn
@@ -53,7 +55,22 @@ internal object Allocator {
5355
}
5456
}
5557

56-
private fun readArrow(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
58+
public enum class ArrowFormat() {
59+
/**
60+
* [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format)
61+
*/
62+
IPC,
63+
64+
/**
65+
* [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files)
66+
*/
67+
FEATHER
68+
}
69+
70+
/**
71+
* Read [ArrowFormat.IPC] data from existing [channel]
72+
*/
73+
public fun readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
5774
ArrowStreamReader(channel, allocator).use { reader ->
5875
val dfs = buildList {
5976
val root = reader.vectorSchemaRoot
@@ -67,7 +84,10 @@ private fun readArrow(channel: ReadableByteChannel, allocator: RootAllocator = A
6784
}
6885
}
6986

70-
private fun readArrow(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
87+
/**
88+
* Read [ArrowFormat.FEATHER] data from existing [channel]
89+
*/
90+
public fun readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
7191
ArrowFileReader(channel, allocator).use { reader ->
7292
val dfs = buildList {
7393
reader.recordBlocks.forEach { block ->
@@ -82,6 +102,8 @@ private fun readArrow(channel: SeekableByteChannel, allocator: RootAllocator = A
82102
}
83103
}
84104

105+
private fun BitVector.values(range: IntRange): List<Boolean?> = range.map { getObject(it) }
106+
85107
private fun UInt1Vector.values(range: IntRange): List<Byte?> = range.map { getObject(it) }
86108
private fun UInt2Vector.values(range: IntRange): List<Char?> = range.map { getObject(it) }
87109
private fun UInt4Vector.values(range: IntRange): List<Long?> = range.map { getObjectNoOverflow(it) }
@@ -146,6 +168,7 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseColumn {
146168
is LargeVarCharVector -> vector.values(range).withType()
147169
is VarBinaryVector -> vector.values(range).withType()
148170
is LargeVarBinaryVector -> vector.values(range).withType()
171+
is BitVector -> vector.values(range).withType()
149172
is SmallIntVector -> vector.values(range).withType()
150173
is TinyIntVector -> vector.values(range).withType()
151174
is UInt1Vector -> vector.values(range).withType()
@@ -171,23 +194,80 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseColumn {
171194
return DataColumn.createValueColumn(field.name, list, type, Infer.Nulls)
172195
}
173196

174-
public fun DataFrame.Companion.readArrow(file: File): AnyFrame {
175-
return Files.newByteChannel(file.toPath()).use { readArrow(it) }
197+
// IPC reading block
198+
199+
private fun DataFrame.Companion.readArrowIPC(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowIPC(it) }
200+
201+
private fun DataFrame.Companion.readArrowIPC(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowIPC(it) }
202+
203+
private fun DataFrame.Companion.readArrowIPC(stream: InputStream): AnyFrame = Channels.newChannel(stream).use { readArrowIPC(it) }
204+
205+
private fun DataFrame.Companion.readArrowIPC(url: URL): AnyFrame =
206+
when {
207+
isFile(url) -> readArrowIPC(urlAsFile(url))
208+
isProtocolSupported(url) -> url.openStream().use { readArrowIPC(it) }
209+
else -> {
210+
throw IllegalArgumentException("Invalid protocol for url $url")
211+
}
212+
}
213+
214+
private fun DataFrame.Companion.readArrowIPC(path: String): AnyFrame = if (isURL(path)) {
215+
readArrowIPC(URL(path))
216+
} else {
217+
readArrowIPC(File(path))
176218
}
177219

178-
public fun DataFrame.Companion.readArrow(stream: InputStream): AnyFrame = Channels.newChannel(stream).use { readArrow(it) }
220+
// Feather reading block
179221

180-
public fun DataFrame.Companion.readArrow(url: URL): AnyFrame =
222+
private fun DataFrame.Companion.readArrowFeather(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowFeather(it) }
223+
224+
private fun DataFrame.Companion.readArrowFeather(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowFeather(it) }
225+
226+
private fun DataFrame.Companion.readArrowFeather(stream: InputStream): AnyFrame = readArrowFeather(stream.readAllBytes())
227+
228+
private fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame =
181229
when {
182-
isFile(url) -> readArrow(urlAsFile(url))
183-
isProtocolSupported(url) -> url.openStream().use { readArrow(it) }
230+
isFile(url) -> readArrowFeather(urlAsFile(url))
231+
isProtocolSupported(url) -> readArrowFeather(url.readBytes())
184232
else -> {
185233
throw IllegalArgumentException("Invalid protocol for url $url")
186234
}
187235
}
188236

189-
public fun DataFrame.Companion.readArrow(path: String): AnyFrame = if (isURL(path)) {
190-
readArrow(URL(path))
237+
private fun DataFrame.Companion.readArrowFeather(path: String): AnyFrame = if (isURL(path)) {
238+
readArrowFeather(URL(path))
191239
} else {
192-
readArrow(File(path))
240+
readArrowFeather(File(path))
193241
}
242+
243+
// Common reading block
244+
245+
public fun DataFrame.Companion.readArrow(file: File, format: ArrowFormat = ArrowFormat.FEATHER): AnyFrame =
246+
when (format) {
247+
ArrowFormat.IPC -> readArrowIPC(file)
248+
ArrowFormat.FEATHER -> readArrowFeather(file)
249+
}
250+
251+
public fun DataFrame.Companion.readArrow(byteArray: ByteArray, format: ArrowFormat = ArrowFormat.FEATHER): AnyFrame =
252+
when (format) {
253+
ArrowFormat.IPC -> readArrowIPC(byteArray)
254+
ArrowFormat.FEATHER -> readArrowFeather(byteArray)
255+
}
256+
257+
public fun DataFrame.Companion.readArrow(stream: InputStream, format: ArrowFormat = ArrowFormat.IPC): AnyFrame =
258+
when (format) {
259+
ArrowFormat.IPC -> readArrowIPC(stream)
260+
ArrowFormat.FEATHER -> readArrowFeather(stream)
261+
}
262+
263+
public fun DataFrame.Companion.readArrow(url: URL, format: ArrowFormat = ArrowFormat.IPC): AnyFrame =
264+
when (format) {
265+
ArrowFormat.IPC -> readArrowIPC(url)
266+
ArrowFormat.FEATHER -> readArrowFeather(url)
267+
}
268+
269+
public fun DataFrame.Companion.readArrow(path: String, format: ArrowFormat = ArrowFormat.IPC): AnyFrame =
270+
when (format) {
271+
ArrowFormat.IPC -> readArrowIPC(path)
272+
ArrowFormat.FEATHER -> readArrowFeather(path)
273+
}

0 commit comments

Comments
 (0)