diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt index 194e5dec3f..64d978c81b 100644 --- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt +++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt @@ -26,6 +26,7 @@ import org.apache.arrow.vector.TinyIntVector import org.apache.arrow.vector.VarCharVector import org.apache.arrow.vector.VariableWidthVector import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.complex.StructVector import org.apache.arrow.vector.types.DateUnit import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.ArrowType @@ -49,8 +50,10 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort import org.jetbrains.kotlinx.dataframe.api.convertToString import org.jetbrains.kotlinx.dataframe.api.forEachIndexed import org.jetbrains.kotlinx.dataframe.api.map +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException +import org.jetbrains.kotlinx.dataframe.indices import org.jetbrains.kotlinx.dataframe.name import org.jetbrains.kotlinx.dataframe.values import kotlin.reflect.full.isSubtypeOf @@ -72,7 +75,15 @@ internal class ArrowWriterImpl( private fun allocateVector(vector: FieldVector, size: Int, totalBytes: Long? = null) { when (vector) { is FixedWidthVector -> vector.allocateNew(size) + is VariableWidthVector -> totalBytes?.let { vector.allocateNew(it, size) } ?: vector.allocateNew(size) + + is StructVector -> { + vector.childrenFromFields.forEach { child -> + allocateVector(child, size) + } + } + else -> throw IllegalArgumentException("Can not allocate ${vector.javaClass.canonicalName}") } } @@ -138,6 +149,8 @@ internal class ArrowWriterImpl( is ArrowType.Time -> column.convertToLocalTime() + is ArrowType.Struct -> column + else -> throw NotImplementedError( "Saving ${targetFieldType.javaClass.canonicalName} is currently not implemented", @@ -277,6 +290,18 @@ internal class ArrowWriterImpl( } ?: vector.setNull(i) } + is StructVector -> { + require(column is ColumnGroup<*>) { + "StructVector expects ColumnGroup, but got ${column::class.simpleName}" + } + + column.columns().forEach { childColumn -> + infillVector(vector.getChild(childColumn.name()), childColumn) + } + + column.indices.forEach { i -> vector.setIndexDefined(i) } + } + else -> { // TODO implement other vector types from [readField] (VarBinaryVector, UIntVector, DurationVector, StructVector) and may be others (ListVector, FixedSizeListVector etc) throw NotImplementedError("Saving to ${vector.javaClass.canonicalName} is currently not implemented") diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt index 5ba09a7598..d982e6256f 100644 --- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt +++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt @@ -19,6 +19,7 @@ import org.apache.arrow.vector.DateMilliVector import org.apache.arrow.vector.Decimal256Vector import org.apache.arrow.vector.DecimalVector import org.apache.arrow.vector.DurationVector +import org.apache.arrow.vector.FieldVector import org.apache.arrow.vector.Float4Vector import org.apache.arrow.vector.Float8Vector import org.apache.arrow.vector.IntVector @@ -293,10 +294,16 @@ private fun List.withTypeNullable( return this to nothingType(nullable) } -private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol { +private fun readField(vector: FieldVector, field: Field, nullability: NullabilityOptions): AnyBaseCol { try { - val range = 0 until root.rowCount - val (list, type) = when (val vector = root.getVector(field)) { + val range = 0 until vector.valueCount + if (vector is StructVector) { + val columns = field.children.map { childField -> + readField(vector.getChild(childField.name), childField, nullability) + } + return DataColumn.createColumnGroup(field.name, columns.toDataFrame()) + } + val (list, type) = when (vector) { is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) @@ -357,8 +364,6 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) - is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) - is NullVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) else -> { @@ -371,6 +376,9 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi } } +private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol = + readField(root.getVector(field), field, nullability) + /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel] */ diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt index 1e337d8d2a..2d3fa13010 100644 --- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt +++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt @@ -11,6 +11,7 @@ import org.apache.arrow.vector.types.pojo.Field import org.apache.arrow.vector.types.pojo.FieldType import org.apache.arrow.vector.types.pojo.Schema import org.jetbrains.kotlinx.dataframe.AnyCol +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.typeClass import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.typeOf @@ -27,6 +28,15 @@ public fun AnyCol.toArrowField(mismatchSubscriber: (ConvertingMismatch) -> Unit val columnType = column.type() val nullable = columnType.isMarkedNullable return when { + column is ColumnGroup<*> -> { + val childFields = column.columns().map { it.toArrowField(mismatchSubscriber) } + Field( + column.name(), + FieldType(nullable, ArrowType.Struct(), null), + childFields, + ) + } + columnType.isSubtypeOf(typeOf()) -> Field( column.name(), diff --git a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt index 679aabae49..9efcc77673 100644 --- a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt +++ b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt @@ -25,7 +25,6 @@ import org.apache.arrow.vector.types.pojo.Field import org.apache.arrow.vector.types.pojo.FieldType import org.apache.arrow.vector.types.pojo.Schema import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel -import org.apache.arrow.vector.util.Text import org.duckdb.DuckDBConnection import org.duckdb.DuckDBResultSet import org.jetbrains.kotlinx.dataframe.AnyFrame @@ -39,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.map import org.jetbrains.kotlinx.dataframe.api.pathOf import org.jetbrains.kotlinx.dataframe.api.remove -import org.jetbrains.kotlinx.dataframe.api.toColumn +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException import org.junit.Assert import org.junit.Test @@ -68,13 +67,11 @@ internal class ArrowKtTest { val df = DataFrame.readArrowFeather(feather) val a by columnOf("one") val b by columnOf(2.0) - val c by listOf( - mapOf( - "c1" to Text("inner"), - "c2" to 4.0, - "c3" to 50.0, - ) as Map, - ).toColumn() + val c by columnOf( + "c1" to columnOf("inner"), + "c2" to columnOf(4.0), + "c3" to columnOf(50.0), + ) val d by columnOf("four") val expected = dataFrameOf(a, b, c, d) df shouldBe expected @@ -728,4 +725,89 @@ internal class ArrowKtTest { dataFrame.rowsCount() shouldBe 900 } + + @Test + fun testColumnGroupRoundtrip() { + val original = dataFrameOf( + "outer" to columnOf("x", "y", "z"), + "inner" to columnOf( + "nested1" to columnOf("a", "b", "c"), + "nested2" to columnOf(1, 2, 3), + ), + ) + + val featherBytes = original.saveArrowFeatherToByteArray() + val fromFeather = DataFrame.readArrowFeather(featherBytes) + fromFeather shouldBe original + + val ipcBytes = original.saveArrowIPCToByteArray() + val fromIpc = DataFrame.readArrowIPC(ipcBytes) + fromIpc shouldBe original + } + + @Test + fun testNestedColumnGroupRoundtrip() { + val deeplyNested by columnOf( + "level2" to columnOf( + "level3" to columnOf(1, 2, 3), + ), + ) + val original = dataFrameOf(deeplyNested) + + val bytes = original.saveArrowFeatherToByteArray() + val restored = DataFrame.readArrowFeather(bytes) + + restored shouldBe original + } + + @Test + fun testColumnGroupWithNulls() { + val group by columnOf( + "a" to columnOf("x", null, "z"), + "b" to columnOf(1, 2, null), + ) + val original = dataFrameOf(group) + + val bytes = original.saveArrowFeatherToByteArray() + val restored = DataFrame.readArrowFeather(bytes) + + restored shouldBe original + } + + @Test + fun testReadParquetWithNestedStruct() { + val resourceUrl = testResource("books.parquet") + val resourcePath = resourceUrl.toURI().toPath() + + val df = DataFrame.readParquet(resourcePath) + + df.columnNames() shouldBe listOf("id", "title", "author", "genre", "publisher") + + val authorGroup = df["author"] as ColumnGroup<*> + authorGroup.columnNames() shouldBe listOf("id", "firstName", "lastName") + + df["id"].type() shouldBe typeOf() + df["title"].type() shouldBe typeOf() + df["genre"].type() shouldBe typeOf() + df["publisher"].type() shouldBe typeOf() + authorGroup["id"].type() shouldBe typeOf() + authorGroup["firstName"].type() shouldBe typeOf() + authorGroup["lastName"].type() shouldBe typeOf() + } + + @Test + fun testParquetNestedStructRoundtrip() { + val resourceUrl = testResource("books.parquet") + val resourcePath = resourceUrl.toURI().toPath() + + val original = DataFrame.readParquet(resourcePath) + + val featherBytes = original.saveArrowFeatherToByteArray() + val fromFeather = DataFrame.readArrowFeather(featherBytes) + fromFeather shouldBe original + + val ipcBytes = original.saveArrowIPCToByteArray() + val fromIpc = DataFrame.readArrowIPC(ipcBytes) + fromIpc shouldBe original + } } diff --git a/dataframe-arrow/src/test/resources/books.parquet b/dataframe-arrow/src/test/resources/books.parquet new file mode 100644 index 0000000000..0fea3a95fd Binary files /dev/null and b/dataframe-arrow/src/test/resources/books.parquet differ