
Commit 2a40584

Add readArrow method and ArrowReader extension to allow loading a dataframe from an ArrowReader #528

1 parent 1ca037b commit 2a40584

4 files changed: +78 -19 lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt

Lines changed: 16 additions & 0 deletions

```diff
@@ -1,6 +1,7 @@
 package org.jetbrains.kotlinx.dataframe.io
 
 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
 import org.jetbrains.kotlinx.dataframe.AnyFrame
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -170,3 +171,18 @@ public fun DataFrame.Companion.readArrowFeather(
     } else {
         readArrowFeather(File(path), nullability)
     }
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [reader]
+ */
+public fun DataFrame.Companion.readArrow(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = readArrowImpl(reader, nullability)
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [ArrowReader]
+ */
+public fun ArrowReader.toDataFrame(
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = DataFrame.Companion.readArrowImpl(this, nullability)
```
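For orientation, here is a minimal usage sketch of the two new entry points. It is not part of the commit, and the file paths are hypothetical placeholders:

```kotlin
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowFileReader
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readArrow
import org.jetbrains.kotlinx.dataframe.io.toDataFrame
import java.io.FileInputStream
import java.nio.file.Files
import java.nio.file.Paths

fun main() {
    RootAllocator().use { allocator ->
        // Random access (Feather) format: ArrowFileReader wants a seekable channel.
        // "data.feather" is a placeholder path.
        Files.newByteChannel(Paths.get("data.feather")).use { channel ->
            val df = DataFrame.readArrow(ArrowFileReader(channel, allocator))
            println(df)
        }

        // Streaming (IPC) format: ArrowStreamReader reads from any InputStream.
        // "data.arrow" is a placeholder path.
        FileInputStream("data.arrow").use { stream ->
            val df = ArrowStreamReader(stream, allocator).toDataFrame()
            println(df)
        }
    }
}
```

Note that `readArrowImpl` (below) wraps the reader in `use`, so both entry points close the reader they are given; only the underlying channel or stream remains the caller's responsibility.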

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt

Lines changed: 31 additions & 18 deletions

```diff
@@ -32,6 +32,7 @@ import org.apache.arrow.vector.VarCharVector
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.complex.StructVector
 import org.apache.arrow.vector.ipc.ArrowFileReader
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.util.DateUtility
@@ -262,17 +263,7 @@ internal fun DataFrame.Companion.readArrowIPCImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowStreamReader(channel, allocator).use { reader ->
-        val flattened = buildList {
-            val root = reader.vectorSchemaRoot
-            val schema = root.schema
-            while (reader.loadNextBatch()) {
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
-            }
-        }
-        return flattened.concatKeepingSchema()
-    }
+    return readArrowImpl(ArrowStreamReader(channel, allocator), nullability)
 }
 
 /**
@@ -283,14 +274,36 @@ internal fun DataFrame.Companion.readArrowFeatherImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowFileReader(channel, allocator).use { reader ->
+    return readArrowImpl(ArrowFileReader(channel, allocator), nullability)
+}
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [reader]
+ */
+internal fun DataFrame.Companion.readArrowImpl(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame {
+    reader.use {
         val flattened = buildList {
-            reader.recordBlocks.forEach { block ->
-                reader.loadRecordBatch(block)
-                val root = reader.vectorSchemaRoot
-                val schema = root.schema
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
+            when (reader) {
+                is ArrowFileReader -> {
+                    reader.recordBlocks.forEach { block ->
+                        reader.loadRecordBatch(block)
+                        val root = reader.vectorSchemaRoot
+                        val schema = root.schema
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
+                is ArrowStreamReader -> {
+                    val root = reader.vectorSchemaRoot
+                    val schema = root.schema
+                    while (reader.loadNextBatch()) {
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
             }
         }
         return flattened.concatKeepingSchema()
```
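A design note, not part of the commit: the `when` inside `readArrowImpl` is non-exhaustive, so an `ArrowReader` subclass other than `ArrowFileReader` or `ArrowStreamReader` would match neither branch and contribute no batches. A hypothetical uniform variant — a sketch only, assuming the abstract `loadNextBatch()` contract (which `ArrowFileReader` also implements) is sufficient, and reusing `readField` and `concatKeepingSchema` from this file — could avoid the type dispatch:

```kotlin
// Sketch, not the committed implementation: iterate batches through the
// abstract ArrowReader API instead of dispatching on the concrete type.
internal fun DataFrame.Companion.readArrowAnySketch(
    reader: ArrowReader,
    nullability: NullabilityOptions = NullabilityOptions.Infer
): AnyFrame = reader.use {
    buildList {
        val root = reader.vectorSchemaRoot
        val schema = root.schema
        // loadNextBatch() advances to the next record batch for any reader type.
        while (reader.loadNextBatch()) {
            add(schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame())
        }
    }.concatKeepingSchema()
}
```

The committed two-branch version has the advantage of preserving both original code paths verbatim, including `ArrowFileReader`'s explicit iteration over `recordBlocks`.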

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 30 additions & 0 deletions

```diff
@@ -9,14 +9,17 @@ import org.apache.arrow.vector.TimeStampMilliVector
 import org.apache.arrow.vector.TimeStampNanoVector
 import org.apache.arrow.vector.TimeStampSecVector
 import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowFileReader
 import org.apache.arrow.vector.ipc.ArrowFileWriter
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.TimeUnit
 import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.types.pojo.FieldType
 import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 import org.apache.arrow.vector.util.Text
 import org.jetbrains.kotlinx.dataframe.DataColumn
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -32,6 +35,7 @@ import org.jetbrains.kotlinx.dataframe.api.remove
 import org.jetbrains.kotlinx.dataframe.api.toColumn
 import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
 import org.junit.Test
+import java.io.ByteArrayInputStream
 import java.io.ByteArrayOutputStream
 import java.io.File
 import java.net.URL
@@ -553,4 +557,30 @@ internal class ArrowKtTest {
             }
         }
     }
+
+    @Test
+    fun testArrowReaderExtension() {
+        val dates = listOf(
+            LocalDateTime.of(2023, 11, 23, 9, 30, 25),
+            LocalDateTime.of(2015, 5, 25, 14, 20, 13),
+            LocalDateTime.of(2013, 6, 19, 11, 20, 13),
+            LocalDateTime.of(2000, 1, 1, 0, 0, 0)
+        )
+
+        val expected = dataFrameOf(
+            "string" to listOf("a", "b", "c", "d"),
+            "int" to listOf(1, 2, 3, 4),
+            "float" to listOf(1.0f, 2.0f, 3.0f, 4.0f),
+            "double" to listOf(1.0, 2.0, 3.0, 4.0),
+            "datetime" to dates
+        )
+
+        val featherChannel = ByteArrayReadableSeekableByteChannel(expected.saveArrowFeatherToByteArray())
+        val arrowFileReader = ArrowFileReader(featherChannel, RootAllocator())
+        arrowFileReader.toDataFrame() shouldBe expected
+
+        val ipcInputStream = ByteArrayInputStream(expected.saveArrowIPCToByteArray())
+        val arrowStreamReader = ArrowStreamReader(ipcInputStream, RootAllocator())
+        arrowStreamReader.toDataFrame() shouldBe expected
+    }
 }
```

docs/StardustDocs/topics/read.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -445,7 +445,7 @@ val df = DataFrame.readArrowFeather(file)
 
 [`DataFrame`](DataFrame.md) supports reading [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format)
 and [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files)
-from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), InputStream, File or ByteArray.
+from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), ArrowReader, InputStream, File or ByteArray.
 
 > If you use Java 9+, follow the [Apache Arrow Java compatibility](https://arrow.apache.org/docs/java/install.html#java-compatibility) guide.
 >
```
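A snippet that could accompany this sentence in read.md (a sketch, not part of the commit; "data.arrow" is a placeholder path for an IPC stream file):

```kotlin
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readArrow
import java.io.FileInputStream

// readArrow closes the reader when done; the stream is closed by use.
val df = FileInputStream("data.arrow").use { stream ->
    DataFrame.readArrow(ArrowStreamReader(stream, RootAllocator()))
}
```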
