Skip to content

Commit d9f9569

Browse files
committed
Support writing ColumnGroup into Arrow as StructVector
1 parent 7eac71f commit d9f9569

File tree

4 files changed

+121
-0
lines changed

4 files changed

+121
-0
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.TinyIntVector
2626
import org.apache.arrow.vector.VarCharVector
2727
import org.apache.arrow.vector.VariableWidthVector
2828
import org.apache.arrow.vector.VectorSchemaRoot
29+
import org.apache.arrow.vector.complex.StructVector
2930
import org.apache.arrow.vector.types.DateUnit
3031
import org.apache.arrow.vector.types.FloatingPointPrecision
3132
import org.apache.arrow.vector.types.pojo.ArrowType
@@ -49,8 +50,10 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort
4950
import org.jetbrains.kotlinx.dataframe.api.convertToString
5051
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
5152
import org.jetbrains.kotlinx.dataframe.api.map
53+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
5254
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
5355
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
56+
import org.jetbrains.kotlinx.dataframe.indices
5457
import org.jetbrains.kotlinx.dataframe.name
5558
import org.jetbrains.kotlinx.dataframe.values
5659
import kotlin.reflect.full.isSubtypeOf
@@ -72,7 +75,15 @@ internal class ArrowWriterImpl(
7275
private fun allocateVector(vector: FieldVector, size: Int, totalBytes: Long? = null) {
7376
when (vector) {
7477
is FixedWidthVector -> vector.allocateNew(size)
78+
7579
is VariableWidthVector -> totalBytes?.let { vector.allocateNew(it, size) } ?: vector.allocateNew(size)
80+
81+
is StructVector -> {
82+
vector.childrenFromFields.forEach { child ->
83+
allocateVector(child, size)
84+
}
85+
}
86+
7687
else -> throw IllegalArgumentException("Can not allocate ${vector.javaClass.canonicalName}")
7788
}
7889
}
@@ -138,6 +149,8 @@ internal class ArrowWriterImpl(
138149

139150
is ArrowType.Time -> column.convertToLocalTime()
140151

152+
is ArrowType.Struct -> column
153+
141154
else ->
142155
throw NotImplementedError(
143156
"Saving ${targetFieldType.javaClass.canonicalName} is currently not implemented",
@@ -277,6 +290,18 @@ internal class ArrowWriterImpl(
277290
} ?: vector.setNull(i)
278291
}
279292

293+
is StructVector -> {
294+
require(column is ColumnGroup<*>) {
295+
"StructVector expects ColumnGroup, but got ${column::class.simpleName}"
296+
}
297+
298+
column.columns().forEach { childColumn ->
299+
infillVector(vector.getChild(childColumn.name()), childColumn)
300+
}
301+
302+
column.indices.forEach { i -> vector.setIndexDefined(i) }
303+
}
304+
280305
else -> {
281306
// TODO implement other vector types from [readField] (VarBinaryVector, UIntVector, DurationVector, StructVector) and may be others (ListVector, FixedSizeListVector etc)
282307
throw NotImplementedError("Saving to ${vector.javaClass.canonicalName} is currently not implemented")

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.apache.arrow.vector.types.pojo.Field
1111
import org.apache.arrow.vector.types.pojo.FieldType
1212
import org.apache.arrow.vector.types.pojo.Schema
1313
import org.jetbrains.kotlinx.dataframe.AnyCol
14+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
1415
import org.jetbrains.kotlinx.dataframe.typeClass
1516
import kotlin.reflect.full.isSubtypeOf
1617
import kotlin.reflect.typeOf
@@ -27,6 +28,15 @@ public fun AnyCol.toArrowField(mismatchSubscriber: (ConvertingMismatch) -> Unit
2728
val columnType = column.type()
2829
val nullable = columnType.isMarkedNullable
2930
return when {
31+
column is ColumnGroup<*> -> {
32+
val childFields = column.columns().map { it.toArrowField(mismatchSubscriber) }
33+
Field(
34+
column.name(),
35+
FieldType(nullable, ArrowType.Struct(), null),
36+
childFields,
37+
)
38+
}
39+
3040
columnType.isSubtypeOf(typeOf<String?>()) ->
3141
Field(
3242
column.name(),

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
3838
import org.jetbrains.kotlinx.dataframe.api.map
3939
import org.jetbrains.kotlinx.dataframe.api.pathOf
4040
import org.jetbrains.kotlinx.dataframe.api.remove
41+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
4142
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
4243
import org.junit.Assert
4344
import org.junit.Test
@@ -724,4 +725,89 @@ internal class ArrowKtTest {
724725

725726
dataFrame.rowsCount() shouldBe 900
726727
}
728+
729+
@Test
730+
fun testColumnGroupRoundtrip() {
731+
val original = dataFrameOf(
732+
"outer" to columnOf("x", "y", "z"),
733+
"inner" to columnOf(
734+
"nested1" to columnOf("a", "b", "c"),
735+
"nested2" to columnOf(1, 2, 3),
736+
),
737+
)
738+
739+
val featherBytes = original.saveArrowFeatherToByteArray()
740+
val fromFeather = DataFrame.readArrowFeather(featherBytes)
741+
fromFeather shouldBe original
742+
743+
val ipcBytes = original.saveArrowIPCToByteArray()
744+
val fromIpc = DataFrame.readArrowIPC(ipcBytes)
745+
fromIpc shouldBe original
746+
}
747+
748+
@Test
749+
fun testNestedColumnGroupRoundtrip() {
750+
val deeplyNested by columnOf(
751+
"level2" to columnOf(
752+
"level3" to columnOf(1, 2, 3),
753+
),
754+
)
755+
val original = dataFrameOf(deeplyNested)
756+
757+
val bytes = original.saveArrowFeatherToByteArray()
758+
val restored = DataFrame.readArrowFeather(bytes)
759+
760+
restored shouldBe original
761+
}
762+
763+
@Test
764+
fun testColumnGroupWithNulls() {
765+
val group by columnOf(
766+
"a" to columnOf("x", null, "z"),
767+
"b" to columnOf(1, 2, null),
768+
)
769+
val original = dataFrameOf(group)
770+
771+
val bytes = original.saveArrowFeatherToByteArray()
772+
val restored = DataFrame.readArrowFeather(bytes)
773+
774+
restored shouldBe original
775+
}
776+
777+
@Test
778+
fun testReadParquetWithNestedStruct() {
779+
val resourceUrl = testResource("books.parquet")
780+
val resourcePath = resourceUrl.toURI().toPath()
781+
782+
val df = DataFrame.readParquet(resourcePath)
783+
784+
df.columnNames() shouldBe listOf("id", "title", "author", "genre", "publisher")
785+
786+
val authorGroup = df["author"] as ColumnGroup<*>
787+
authorGroup.columnNames() shouldBe listOf("id", "firstName", "lastName")
788+
789+
df["id"].type() shouldBe typeOf<Int>()
790+
df["title"].type() shouldBe typeOf<String>()
791+
df["genre"].type() shouldBe typeOf<String>()
792+
df["publisher"].type() shouldBe typeOf<String>()
793+
authorGroup["id"].type() shouldBe typeOf<Int>()
794+
authorGroup["firstName"].type() shouldBe typeOf<String>()
795+
authorGroup["lastName"].type() shouldBe typeOf<String>()
796+
}
797+
798+
@Test
799+
fun testParquetNestedStructRoundtrip() {
800+
val resourceUrl = testResource("books.parquet")
801+
val resourcePath = resourceUrl.toURI().toPath()
802+
803+
val original = DataFrame.readParquet(resourcePath)
804+
805+
val featherBytes = original.saveArrowFeatherToByteArray()
806+
val fromFeather = DataFrame.readArrowFeather(featherBytes)
807+
fromFeather shouldBe original
808+
809+
val ipcBytes = original.saveArrowIPCToByteArray()
810+
val fromIpc = DataFrame.readArrowIPC(ipcBytes)
811+
fromIpc shouldBe original
812+
}
727813
}
5.61 KB
Binary file not shown.

0 commit comments

Comments
 (0)