Skip to content

Commit cae2786

Browse files
Merge remote-tracking branch 'origin/master'
2 parents 7072e3e + 46c9afd commit cae2786

File tree

47 files changed

+786
-158
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+786
-158
lines changed

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ val modulesUsingJava11 = with(projects) {
157157
dataframeJupyter,
158158
dataframeGeoJupyter,
159159
examples.ideaExamples.titanic,
160-
examples.ideaExamples.unsupportedDataSources,
160+
examples.ideaExamples.unsupportedDataSources.hibernate,
161161
samples,
162162
plugins.dataframeGradlePlugin,
163163
)

core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Render.kt

Lines changed: 0 additions & 48 deletions
This file was deleted.

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Render.kt

Lines changed: 0 additions & 48 deletions
This file was deleted.

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriterImpl.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.TinyIntVector
2626
import org.apache.arrow.vector.VarCharVector
2727
import org.apache.arrow.vector.VariableWidthVector
2828
import org.apache.arrow.vector.VectorSchemaRoot
29+
import org.apache.arrow.vector.complex.StructVector
2930
import org.apache.arrow.vector.types.DateUnit
3031
import org.apache.arrow.vector.types.FloatingPointPrecision
3132
import org.apache.arrow.vector.types.pojo.ArrowType
@@ -49,8 +50,10 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort
4950
import org.jetbrains.kotlinx.dataframe.api.convertToString
5051
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
5152
import org.jetbrains.kotlinx.dataframe.api.map
53+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
5254
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
5355
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
56+
import org.jetbrains.kotlinx.dataframe.indices
5457
import org.jetbrains.kotlinx.dataframe.name
5558
import org.jetbrains.kotlinx.dataframe.values
5659
import kotlin.reflect.full.isSubtypeOf
@@ -72,7 +75,15 @@ internal class ArrowWriterImpl(
7275
/**
 * Allocates storage in [vector] for [size] values, dispatching on the vector kind.
 *
 * - Fixed-width vectors are allocated by value count only.
 * - Variable-width vectors are allocated with [totalBytes] and [size] when [totalBytes]
 *   is given, otherwise by [size] only.
 * - Struct vectors allocate each child vector recursively with the same [size];
 *   [totalBytes] is not forwarded to the children, and no allocation call is made on the
 *   struct vector itself in this function.
 *
 * @param vector the Arrow vector to allocate.
 * @param size number of values to reserve room for.
 * @param totalBytes optional data size in bytes, used only for variable-width vectors.
 * @throws IllegalArgumentException if [vector] is of any other kind.
 */
private fun allocateVector(vector: FieldVector, size: Int, totalBytes: Long? = null) {
    when (vector) {
        is FixedWidthVector -> vector.allocateNew(size)

        is VariableWidthVector ->
            if (totalBytes != null) {
                vector.allocateNew(totalBytes, size)
            } else {
                vector.allocateNew(size)
            }

        is StructVector ->
            for (childVector in vector.childrenFromFields) {
                allocateVector(childVector, size)
            }

        else -> throw IllegalArgumentException("Can not allocate ${vector.javaClass.canonicalName}")
    }
}
@@ -138,6 +149,8 @@ internal class ArrowWriterImpl(
138149

139150
is ArrowType.Time -> column.convertToLocalTime()
140151

152+
is ArrowType.Struct -> column
153+
141154
else ->
142155
throw NotImplementedError(
143156
"Saving ${targetFieldType.javaClass.canonicalName} is currently not implemented",
@@ -277,6 +290,18 @@ internal class ArrowWriterImpl(
277290
} ?: vector.setNull(i)
278291
}
279292

293+
is StructVector -> {
294+
require(column is ColumnGroup<*>) {
295+
"StructVector expects ColumnGroup, but got ${column::class.simpleName}"
296+
}
297+
298+
column.columns().forEach { childColumn ->
299+
infillVector(vector.getChild(childColumn.name()), childColumn)
300+
}
301+
302+
column.indices.forEach { i -> vector.setIndexDefined(i) }
303+
}
304+
280305
else -> {
281306
// TODO implement other vector types from [readField] (VarBinaryVector, UIntVector, DurationVector) and maybe others (ListVector, FixedSizeListVector, etc.)
282307
throw NotImplementedError("Saving to ${vector.javaClass.canonicalName} is currently not implemented")

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import org.apache.arrow.vector.DateMilliVector
1919
import org.apache.arrow.vector.Decimal256Vector
2020
import org.apache.arrow.vector.DecimalVector
2121
import org.apache.arrow.vector.DurationVector
22+
import org.apache.arrow.vector.FieldVector
2223
import org.apache.arrow.vector.Float4Vector
2324
import org.apache.arrow.vector.Float8Vector
2425
import org.apache.arrow.vector.IntVector
@@ -293,10 +294,16 @@ private fun List<Nothing?>.withTypeNullable(
293294
return this to nothingType(nullable)
294295
}
295296

296-
private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {
297+
private fun readField(vector: FieldVector, field: Field, nullability: NullabilityOptions): AnyBaseCol {
297298
try {
298-
val range = 0 until root.rowCount
299-
val (list, type) = when (val vector = root.getVector(field)) {
299+
val range = 0 until vector.valueCount
300+
if (vector is StructVector) {
301+
val columns = field.children.map { childField ->
302+
readField(vector.getChild(childField.name), childField, nullability)
303+
}
304+
return DataColumn.createColumnGroup(field.name, columns.toDataFrame())
305+
}
306+
val (list, type) = when (vector) {
300307
is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
301308

302309
is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
@@ -357,8 +364,6 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
357364

358365
is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
359366

360-
is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
361-
362367
is NullVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
363368

364369
else -> {
@@ -371,6 +376,9 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
371376
}
372377
}
373378

379+
/**
 * Convenience overload: obtains the vector for [field] from [root] and delegates to the
 * vector-based [readField] with the same [field] and [nullability].
 */
private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {
    val fieldVector = root.getVector(field)
    return readField(fieldVector, field, nullability)
}
381+
374382
/**
375383
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
376384
*/

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowTypesMatching.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.apache.arrow.vector.types.pojo.Field
1111
import org.apache.arrow.vector.types.pojo.FieldType
1212
import org.apache.arrow.vector.types.pojo.Schema
1313
import org.jetbrains.kotlinx.dataframe.AnyCol
14+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
1415
import org.jetbrains.kotlinx.dataframe.typeClass
1516
import kotlin.reflect.full.isSubtypeOf
1617
import kotlin.reflect.typeOf
@@ -27,6 +28,15 @@ public fun AnyCol.toArrowField(mismatchSubscriber: (ConvertingMismatch) -> Unit
2728
val columnType = column.type()
2829
val nullable = columnType.isMarkedNullable
2930
return when {
31+
column is ColumnGroup<*> -> {
32+
val childFields = column.columns().map { it.toArrowField(mismatchSubscriber) }
33+
Field(
34+
column.name(),
35+
FieldType(nullable, ArrowType.Struct(), null),
36+
childFields,
37+
)
38+
}
39+
3040
columnType.isSubtypeOf(typeOf<String?>()) ->
3141
Field(
3242
column.name(),

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.apache.arrow.vector.types.pojo.Field
2525
import org.apache.arrow.vector.types.pojo.FieldType
2626
import org.apache.arrow.vector.types.pojo.Schema
2727
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
28-
import org.apache.arrow.vector.util.Text
2928
import org.duckdb.DuckDBConnection
3029
import org.duckdb.DuckDBResultSet
3130
import org.jetbrains.kotlinx.dataframe.AnyFrame
@@ -39,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
3938
import org.jetbrains.kotlinx.dataframe.api.map
4039
import org.jetbrains.kotlinx.dataframe.api.pathOf
4140
import org.jetbrains.kotlinx.dataframe.api.remove
42-
import org.jetbrains.kotlinx.dataframe.api.toColumn
41+
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
4342
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
4443
import org.junit.Assert
4544
import org.junit.Test
@@ -68,13 +67,11 @@ internal class ArrowKtTest {
6867
val df = DataFrame.readArrowFeather(feather)
6968
val a by columnOf("one")
7069
val b by columnOf(2.0)
71-
val c by listOf(
72-
mapOf(
73-
"c1" to Text("inner"),
74-
"c2" to 4.0,
75-
"c3" to 50.0,
76-
) as Map<String, Any?>,
77-
).toColumn()
70+
val c by columnOf(
71+
"c1" to columnOf("inner"),
72+
"c2" to columnOf(4.0),
73+
"c3" to columnOf(50.0),
74+
)
7875
val d by columnOf("four")
7976
val expected = dataFrameOf(a, b, c, d)
8077
df shouldBe expected
@@ -728,4 +725,89 @@ internal class ArrowKtTest {
728725

729726
dataFrame.rowsCount() shouldBe 900
730727
}
728+
729+
/**
 * A frame with a plain column next to a column group must be equal to itself after being
 * written to and read back from Feather bytes, and likewise for IPC bytes.
 */
@Test
fun testColumnGroupRoundtrip() {
    val source = dataFrameOf(
        "outer" to columnOf("x", "y", "z"),
        "inner" to columnOf(
            "nested1" to columnOf("a", "b", "c"),
            "nested2" to columnOf(1, 2, 3),
        ),
    )

    // Feather: write to bytes, read back, compare.
    val viaFeather = DataFrame.readArrowFeather(source.saveArrowFeatherToByteArray())
    viaFeather shouldBe source

    // IPC: write to bytes, read back, compare.
    val viaIpc = DataFrame.readArrowIPC(source.saveArrowIPCToByteArray())
    viaIpc shouldBe source
}
747+
748+
/**
 * A column group nested inside another column group must survive a Feather
 * write/read cycle unchanged.
 */
@Test
fun testNestedColumnGroupRoundtrip() {
    // The delegated property name doubles as the name of the outermost column.
    val deeplyNested by columnOf(
        "level2" to columnOf("level3" to columnOf(1, 2, 3)),
    )
    val source = dataFrameOf(deeplyNested)

    val roundTripped = DataFrame.readArrowFeather(source.saveArrowFeatherToByteArray())

    roundTripped shouldBe source
}
762+
763+
/**
 * A column group whose child columns contain `null` values must survive a Feather
 * write/read cycle unchanged.
 */
@Test
fun testColumnGroupWithNulls() {
    // The delegated property name doubles as the name of the group column.
    val group by columnOf(
        "a" to columnOf("x", null, "z"),
        "b" to columnOf(1, 2, null),
    )
    val source = dataFrameOf(group)

    val roundTripped = DataFrame.readArrowFeather(source.saveArrowFeatherToByteArray())

    roundTripped shouldBe source
}
776+
777+
/**
 * Reading `books.parquet` must expose the `author` struct as a column group and report
 * the expected column names and types at both the top level and inside the group.
 */
@Test
fun testReadParquetWithNestedStruct() {
    val books = DataFrame.readParquet(testResource("books.parquet").toURI().toPath())

    books.columnNames() shouldBe listOf("id", "title", "author", "genre", "publisher")

    val author = books["author"] as ColumnGroup<*>
    author.columnNames() shouldBe listOf("id", "firstName", "lastName")

    // Top-level scalar columns.
    books["id"].type() shouldBe typeOf<Int>()
    for (textColumn in listOf("title", "genre", "publisher")) {
        books[textColumn].type() shouldBe typeOf<String>()
    }

    // Columns nested inside the `author` group.
    author["id"].type() shouldBe typeOf<Int>()
    for (textColumn in listOf("firstName", "lastName")) {
        author[textColumn].type() shouldBe typeOf<String>()
    }
}
797+
798+
/**
 * The frame read from `books.parquet` must be equal to itself after being written to and
 * read back from Feather bytes, and likewise for IPC bytes.
 */
@Test
fun testParquetNestedStructRoundtrip() {
    val parquetPath = testResource("books.parquet").toURI().toPath()
    val source = DataFrame.readParquet(parquetPath)

    // Feather: write to bytes, read back, compare.
    val viaFeather = DataFrame.readArrowFeather(source.saveArrowFeatherToByteArray())
    viaFeather shouldBe source

    // IPC: write to bytes, read back, compare.
    val viaIpc = DataFrame.readArrowIPC(source.saveArrowIPCToByteArray())
    viaIpc shouldBe source
}
731813
}
5.61 KB
Binary file not shown.
497 KB
Loading
484 KB
Loading

0 commit comments

Comments
 (0)