@@ -26,6 +26,7 @@ import org.apache.arrow.vector.TinyIntVector
import org.apache.arrow.vector.VarCharVector
import org.apache.arrow.vector.VariableWidthVector
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.complex.StructVector
import org.apache.arrow.vector.types.DateUnit
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.ArrowType
@@ -49,8 +50,10 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort
import org.jetbrains.kotlinx.dataframe.api.convertToString
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
import org.jetbrains.kotlinx.dataframe.indices
import org.jetbrains.kotlinx.dataframe.name
import org.jetbrains.kotlinx.dataframe.values
import kotlin.reflect.full.isSubtypeOf
@@ -72,7 +75,15 @@ internal class ArrowWriterImpl(
private fun allocateVector(vector: FieldVector, size: Int, totalBytes: Long? = null) {
when (vector) {
is FixedWidthVector -> vector.allocateNew(size)

is VariableWidthVector -> totalBytes?.let { vector.allocateNew(it, size) } ?: vector.allocateNew(size)

is StructVector -> {
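// A struct's data lives in its child vectors; allocate each child recursively.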
vector.childrenFromFields.forEach { child ->
allocateVector(child, size)
}
}

else -> throw IllegalArgumentException("Can not allocate ${vector.javaClass.canonicalName}")
}
}
@@ -138,6 +149,8 @@ internal class ArrowWriterImpl(

is ArrowType.Time -> column.convertToLocalTime()

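// A struct target needs no top-level conversion; its children are written individually by infillVector.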
is ArrowType.Struct -> column

else ->
throw NotImplementedError(
"Saving ${targetFieldType.javaClass.canonicalName} is currently not implemented",
@@ -277,6 +290,18 @@ internal class ArrowWriterImpl(
} ?: vector.setNull(i)
}

is StructVector -> {
require(column is ColumnGroup<*>) {
"StructVector expects ColumnGroup, but got ${column::class.simpleName}"
}

column.columns().forEach { childColumn ->
infillVector(vector.getChild(childColumn.name()), childColumn)
}

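// A ColumnGroup has no per-row nulls of its own, so every struct slot is marked as defined.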
column.indices.forEach { i -> vector.setIndexDefined(i) }
}

else -> {
// TODO implement other vector types from [readField] (VarBinaryVector, UIntVector, DurationVector) and maybe others (ListVector, FixedSizeListVector, etc.)
throw NotImplementedError("Saving to ${vector.javaClass.canonicalName} is currently not implemented")
@@ -19,6 +19,7 @@ import org.apache.arrow.vector.DateMilliVector
import org.apache.arrow.vector.Decimal256Vector
import org.apache.arrow.vector.DecimalVector
import org.apache.arrow.vector.DurationVector
import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.Float4Vector
import org.apache.arrow.vector.Float8Vector
import org.apache.arrow.vector.IntVector
@@ -293,10 +294,16 @@ private fun List<Nothing?>.withTypeNullable(
return this to nothingType(nullable)
}

-private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {
+private fun readField(vector: FieldVector, field: Field, nullability: NullabilityOptions): AnyBaseCol {
try {
-val range = 0 until root.rowCount
-val (list, type) = when (val vector = root.getVector(field)) {
+val range = 0 until vector.valueCount
if (vector is StructVector) {
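// A struct is read by recursing into its children and regrouping them as a ColumnGroup.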
val columns = field.children.map { childField ->
readField(vector.getChild(childField.name), childField, nullability)
}
return DataColumn.createColumnGroup(field.name, columns.toDataFrame())
}
val (list, type) = when (vector) {
is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)

is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
@@ -357,8 +364,6 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {

is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)

-is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)

is NullVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)

else -> {
@@ -371,6 +376,9 @@
}
}

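// Thin overload for callers that hold a VectorSchemaRoot: resolve the field's vector and delegate.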
private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol =
readField(root.getVector(field), field, nullability)

/**
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
*/
@@ -11,6 +11,7 @@ import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.FieldType
import org.apache.arrow.vector.types.pojo.Schema
import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.typeClass
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf
@@ -27,6 +28,15 @@ public fun AnyCol.toArrowField(mismatchSubscriber: (ConvertingMismatch) -> Unit
val columnType = column.type()
val nullable = columnType.isMarkedNullable
return when {
column is ColumnGroup<*> -> {
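// A ColumnGroup maps to an Arrow Struct field whose children mirror the group's columns.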
val childFields = column.columns().map { it.toArrowField(mismatchSubscriber) }
Field(
column.name(),
FieldType(nullable, ArrowType.Struct(), null),
childFields,
)
}

columnType.isSubtypeOf(typeOf<String?>()) ->
Field(
column.name(),
@@ -25,7 +25,6 @@ import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.FieldType
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
-import org.apache.arrow.vector.util.Text
import org.duckdb.DuckDBConnection
import org.duckdb.DuckDBResultSet
import org.jetbrains.kotlinx.dataframe.AnyFrame
@@ -39,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.api.pathOf
import org.jetbrains.kotlinx.dataframe.api.remove
-import org.jetbrains.kotlinx.dataframe.api.toColumn
+import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
import org.junit.Assert
import org.junit.Test
@@ -68,13 +67,11 @@ internal class ArrowKtTest {
val df = DataFrame.readArrowFeather(feather)
val a by columnOf("one")
val b by columnOf(2.0)
-val c by listOf(
-mapOf(
-"c1" to Text("inner"),
-"c2" to 4.0,
-"c3" to 50.0,
-) as Map<String, Any?>,
-).toColumn()
+val c by columnOf(
+"c1" to columnOf("inner"),
+"c2" to columnOf(4.0),
+"c3" to columnOf(50.0),
+)
val d by columnOf("four")
val expected = dataFrameOf(a, b, c, d)
df shouldBe expected
@@ -728,4 +725,89 @@

dataFrame.rowsCount() shouldBe 900
}

@Test
fun testColumnGroupRoundtrip() {
val original = dataFrameOf(
"outer" to columnOf("x", "y", "z"),
"inner" to columnOf(
"nested1" to columnOf("a", "b", "c"),
"nested2" to columnOf(1, 2, 3),
),
)

val featherBytes = original.saveArrowFeatherToByteArray()
val fromFeather = DataFrame.readArrowFeather(featherBytes)
fromFeather shouldBe original

val ipcBytes = original.saveArrowIPCToByteArray()
val fromIpc = DataFrame.readArrowIPC(ipcBytes)
fromIpc shouldBe original
}

@Test
fun testNestedColumnGroupRoundtrip() {
val deeplyNested by columnOf(
"level2" to columnOf(
"level3" to columnOf(1, 2, 3),
),
)
val original = dataFrameOf(deeplyNested)

val bytes = original.saveArrowFeatherToByteArray()
val restored = DataFrame.readArrowFeather(bytes)

restored shouldBe original
}

@Test
fun testColumnGroupWithNulls() {
val group by columnOf(
"a" to columnOf("x", null, "z"),
"b" to columnOf(1, 2, null),
)
val original = dataFrameOf(group)

val bytes = original.saveArrowFeatherToByteArray()
val restored = DataFrame.readArrowFeather(bytes)

restored shouldBe original
}

@Test
fun testReadParquetWithNestedStruct() {
val resourceUrl = testResource("books.parquet")
val resourcePath = resourceUrl.toURI().toPath()

val df = DataFrame.readParquet(resourcePath)

df.columnNames() shouldBe listOf("id", "title", "author", "genre", "publisher")

val authorGroup = df["author"] as ColumnGroup<*>
authorGroup.columnNames() shouldBe listOf("id", "firstName", "lastName")

df["id"].type() shouldBe typeOf<Int>()
df["title"].type() shouldBe typeOf<String>()
df["genre"].type() shouldBe typeOf<String>()
df["publisher"].type() shouldBe typeOf<String>()
authorGroup["id"].type() shouldBe typeOf<Int>()
authorGroup["firstName"].type() shouldBe typeOf<String>()
authorGroup["lastName"].type() shouldBe typeOf<String>()
}

@Test
fun testParquetNestedStructRoundtrip() {
val resourceUrl = testResource("books.parquet")
val resourcePath = resourceUrl.toURI().toPath()

val original = DataFrame.readParquet(resourcePath)

val featherBytes = original.saveArrowFeatherToByteArray()
val fromFeather = DataFrame.readArrowFeather(featherBytes)
fromFeather shouldBe original

val ipcBytes = original.saveArrowIPCToByteArray()
val fromIpc = DataFrame.readArrowIPC(ipcBytes)
fromIpc shouldBe original
}
}
Binary file added dataframe-arrow/src/test/resources/books.parquet
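
For context, here is a minimal usage sketch of the round trip this change enables. The DataFrame construction and the save/read calls mirror the tests above; the import paths are assumed to follow the usual dataframe / dataframe-arrow package layout, and the main wrapper is only for illustration.

import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.columnOf
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.io.readArrowFeather
import org.jetbrains.kotlinx.dataframe.io.saveArrowFeatherToByteArray

fun main() {
    // "inner" is a ColumnGroup; with this change it is written as an Arrow Struct column.
    val df = dataFrameOf(
        "outer" to columnOf("x", "y", "z"),
        "inner" to columnOf(
            "nested1" to columnOf("a", "b", "c"),
            "nested2" to columnOf(1, 2, 3),
        ),
    )

    val bytes = df.saveArrowFeatherToByteArray()
    val restored = DataFrame.readArrowFeather(bytes)

    // The nested structure survives the round trip.
    check(restored == df)
}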