Skip to content

Commit 5ae26a7

Browse files
committed
Arrow Writer draft
1 parent f219cc6 commit 5ae26a7

File tree

4 files changed

+230
-1
lines changed

4 files changed

+230
-1
lines changed

dataframe-arrow/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies {
1414
implementation(libs.arrow.memory)
1515
implementation(libs.commonsCompress)
1616
implementation(libs.kotlin.reflect)
17+
implementation(libs.kotlin.datetimeJvm)
1718

1819
testApi(project(":core"))
1920
testImplementation(libs.junit)
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package org.jetbrains.kotlinx.dataframe.io
2+
3+
import kotlinx.datetime.TimeZone
4+
import kotlinx.datetime.toInstant
5+
import kotlinx.datetime.toJavaLocalDate
6+
import org.apache.arrow.memory.RootAllocator
7+
import org.apache.arrow.vector.*
8+
import org.apache.arrow.vector.types.DateUnit
9+
import org.apache.arrow.vector.types.FloatingPointPrecision
10+
import org.apache.arrow.vector.types.TimeUnit
11+
import org.apache.arrow.vector.types.pojo.ArrowType
12+
import org.apache.arrow.vector.types.pojo.Field
13+
import org.apache.arrow.vector.types.pojo.FieldType
14+
import org.apache.arrow.vector.types.pojo.Schema
15+
import org.apache.arrow.vector.util.Text
16+
import org.jetbrains.kotlinx.dataframe.AnyCol
17+
import org.jetbrains.kotlinx.dataframe.DataFrame
18+
import org.jetbrains.kotlinx.dataframe.api.*
19+
import java.time.LocalDate
20+
import java.time.LocalDateTime
21+
import java.time.LocalTime
22+
import kotlin.reflect.typeOf
23+
24+
/**
 * Builds an Arrow [Schema] with one [Field] per column of this list, in list order.
 *
 * Every field carries the column name and has no children. The Arrow type and the nullability
 * flag are chosen by comparing the column's Kotlin type against a fixed set of supported types;
 * a column of any other type is declared as a nullable UTF-8 field.
 */
public fun List<AnyCol>.toArrowSchema(): Schema {
    // Single place where a Field is assembled; every branch below only picks the flag and the type.
    fun AnyCol.describedAs(nullable: Boolean, arrowType: ArrowType): Field =
        Field(name(), FieldType(nullable, arrowType, null), emptyList())

    fun signedInt(bitWidth: Int): ArrowType = ArrowType.Int(bitWidth, true)
    fun floating(precision: FloatingPointPrecision): ArrowType = ArrowType.FloatingPoint(precision)

    val schemaFields = ArrayList<Field>(size)
    for (column in this) {
        val described = when (column.type()) {
            typeOf<String?>() -> column.describedAs(true, ArrowType.Utf8())
            typeOf<String>() -> column.describedAs(false, ArrowType.Utf8())

            typeOf<Boolean?>() -> column.describedAs(true, ArrowType.Bool())
            typeOf<Boolean>() -> column.describedAs(false, ArrowType.Bool())

            typeOf<Byte?>() -> column.describedAs(true, signedInt(8))
            typeOf<Byte>() -> column.describedAs(false, signedInt(8))

            typeOf<Short?>() -> column.describedAs(true, signedInt(16))
            typeOf<Short>() -> column.describedAs(false, signedInt(16))

            typeOf<Int?>() -> column.describedAs(true, signedInt(32))
            typeOf<Int>() -> column.describedAs(false, signedInt(32))

            typeOf<Long?>() -> column.describedAs(true, signedInt(64))
            typeOf<Long>() -> column.describedAs(false, signedInt(64))

            typeOf<Float?>() -> column.describedAs(true, floating(FloatingPointPrecision.SINGLE))
            typeOf<Float>() -> column.describedAs(false, floating(FloatingPointPrecision.SINGLE))

            typeOf<Double?>() -> column.describedAs(true, floating(FloatingPointPrecision.DOUBLE))
            typeOf<Double>() -> column.describedAs(false, floating(FloatingPointPrecision.DOUBLE))

            // Both java.time and kotlinx.datetime dates map to a day-resolution Arrow date.
            typeOf<LocalDate?>(), typeOf<kotlinx.datetime.LocalDate?>() ->
                column.describedAs(true, ArrowType.Date(DateUnit.DAY))
            typeOf<LocalDate>(), typeOf<kotlinx.datetime.LocalDate>() ->
                column.describedAs(false, ArrowType.Date(DateUnit.DAY))

            // Date-times are declared as a millisecond-resolution Arrow date.
            typeOf<LocalDateTime?>(), typeOf<kotlinx.datetime.LocalDateTime?>() ->
                column.describedAs(true, ArrowType.Date(DateUnit.MILLISECOND))
            typeOf<LocalDateTime>(), typeOf<kotlinx.datetime.LocalDateTime>() ->
                column.describedAs(false, ArrowType.Date(DateUnit.MILLISECOND))

            typeOf<LocalTime?>() -> column.describedAs(true, ArrowType.Time(TimeUnit.NANOSECOND, 64))
            typeOf<LocalTime>() -> column.describedAs(false, ArrowType.Time(TimeUnit.NANOSECOND, 64))

            // Fallback for every type not listed above: nullable UTF-8.
            else -> column.describedAs(true, ArrowType.Utf8())
        }
        schemaFields.add(described)
    }
    return Schema(schemaFields)
}
65+
66+
/**
 * Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray or stream) with [targetSchema].
 *
 * The writer creates its own [RootAllocator]; every vector built by this class is created with that
 * allocator, and [close] closes the allocator.
 */
public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSchema: Schema) : AutoCloseable {
    private val allocator = RootAllocator()

    /**
     * Allocates buffers in [vector] for [size] values.
     *
     * Only fixed-width and variable-width vectors are handled; any other vector kind ends in [TODO].
     */
    private fun allocateVector(vector: FieldVector, size: Int) {
        when (vector) {
            is FixedWidthVector -> vector.allocateNew(size)
            is VariableWidthVector -> vector.allocateNew(size)
            else -> TODO("Not implemented for ${vector.javaClass.canonicalName}")
        }
    }

    /**
     * Marks the first [size] slots of [vector] as null and sets its value count to [size].
     *
     * Only fixed-width and variable-width base vectors are handled; any other vector kind ends in [TODO].
     */
    private fun infillWithNulls(vector: FieldVector, size: Int) {
        when (vector) {
            is BaseFixedWidthVector -> for (i in 0 until size) { vector.setNull(i) }
            is BaseVariableWidthVector -> for (i in 0 until size) { vector.setNull(i) }
            else -> TODO("Not implemented for ${vector.javaClass.canonicalName}")
        }
        vector.valueCount = size
    }

    /**
     * Copies the values of [column] into [vector], converting the column to the Kotlin type that
     * matches the concrete vector class. Null values become null slots.
     *
     * After filling, the vector's value count is set to the row count of [dataFrame].
     * Vector classes without a branch below end in [TODO].
     */
    private fun infillVector(vector: FieldVector, column: AnyCol) {
        when (vector) {
            // Variable-width vectors are allocated by value count only, so the data buffer may be
            // too small for the actual text. setSafe reallocates when needed; plain set does not.
            is VarCharVector -> column.forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.setSafe(i, Text(value.toString())) }
            is LargeVarCharVector -> column.forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.setSafe(i, Text(value.toString())) }
//            is VarBinaryVector -> vector.values(range).withType()
//            is LargeVarBinaryVector -> vector.values(range).withType()
            // BitVector stores 0/1; compareTo(false) yields 0 for false and 1 for true.
            is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value.compareTo(false)) }
            is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
//            is UInt1Vector -> vector.values(range).withType()
//            is UInt2Vector -> vector.values(range).withType()
//            is UInt4Vector -> vector.values(range).withType()
//            is UInt8Vector -> vector.values(range).withType()
            is IntVector -> column.convertToInt().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }
            is Float4Vector -> column.convertToFloat().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value) }

            // Dates are stored as days since the epoch.
            is DateDayVector -> column.convertToLocalDate().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value.toJavaLocalDate().toEpochDay().toInt()) }
            // Date-times are interpreted in UTC and stored as epoch milliseconds.
            is DateMilliVector -> column.convertToLocalDateTime().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value.toInstant(TimeZone.UTC).toEpochMilliseconds()) }
//            is DurationVector -> vector.values(range).withType()
            // Times of day are scaled down from nanoseconds to the unit of the target vector.
            is TimeNanoVector -> column.convertToLocalTime().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value.toNanoOfDay()) }
            is TimeMicroVector -> column.convertToLocalTime().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, value.toNanoOfDay() / 1000) }
            is TimeMilliVector -> column.convertToLocalTime().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, (value.toNanoOfDay() / 1000 / 1000).toInt()) }
            is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> if (value == null) vector.setNull(i) else vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000).toInt()) }
//            is StructVector -> vector.values(range).withType()
            else -> {
                TODO("not fully implemented, ${vector.javaClass.canonicalName}")
            }
        }

        vector.valueCount = dataFrame.rowsCount()
    }

    /**
     * Creates, allocates and fills one vector per column of this list, using the schema that
     * [toArrowSchema] derives from the columns themselves.
     */
    private fun List<AnyCol>.toVectors(): List<FieldVector> {
        val rowCount = dataFrame.rowsCount()
        // Schema fields are produced in column order, so index i pairs each field with its column.
        return toArrowSchema().fields.mapIndexed { i, field ->
            val vector: FieldVector = field.createVector(allocator)
            allocateVector(vector, rowCount)
            infillVector(vector, this@toVectors[i])
            vector
        }
    }

    /**
     * Create Arrow VectorSchemaRoot with [dataFrame] content cast to [targetSchema].
     * If [restrictWidening] is true, [dataFrame] columns not described in [targetSchema] would not be saved (otherwise, would be saved as is).
     * If [restrictNarrowing] is true, [targetSchema] fields that are not nullable and do not exist in [dataFrame] will produce exception (otherwise, would not be saved).
     * If [strictType] is true, [dataFrame] columns described in [targetSchema] with non-compatible type will produce exception (otherwise, would be saved as is).
     * If [strictNullable] is true, [targetSchema] fields that are not nullable and contain nulls in [dataFrame] will produce exception (otherwise, would be saved as is with nullable = true).
     *
     * NOTE: [strictType] and [strictNullable] are accepted but not consulted by the current implementation.
     *
     * @throws IllegalArgumentException if [restrictNarrowing] is true and a non-nullable field of
     * [targetSchema] has no column with the same name in [dataFrame].
     */
    private fun allocateVectorSchemaRoot(
        restrictWidening: Boolean = true,
        restrictNarrowing: Boolean = true,
        strictType: Boolean = true,
        strictNullable: Boolean = true
    ): VectorSchemaRoot {
        val rowCount = dataFrame.rowsCount()
        // Insertion-ordered so the resulting vectors follow the field order of targetSchema.
        val mainVectors = LinkedHashMap<String, FieldVector>()
        for (field in targetSchema.fields) {
            val column = dataFrame.getColumnOrNull(field.name)
            if (column == null && !field.isNullable) {
                if (restrictNarrowing) {
                    throw IllegalArgumentException("${field.name} column is not presented")
                }
                continue
            }

            // Create the vector only once it is certain the field will be written, so skipped
            // or rejected fields do not leave an orphaned vector behind.
            val vector: FieldVector = field.createVector(allocator)
            allocateVector(vector, rowCount)
            if (column == null) {
                // A missing column can only reach this point for a nullable field.
                check(field.isNullable)
                infillWithNulls(vector, rowCount)
            } else {
                infillVector(vector, column)
            }
            mainVectors[field.name] = vector
        }

        val vectors = ArrayList<FieldVector>(mainVectors.values)
        if (!restrictWidening) {
            // Columns that targetSchema does not mention are appended with an inferred schema.
            vectors.addAll(dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }.toVectors())
        }
        return VectorSchemaRoot(vectors)
    }

    /** Closes the allocator owned by this writer. */
    override fun close() {
        allocator.close()
    }
}

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod
5050
import org.jetbrains.kotlinx.dataframe.impl.asList
5151
import java.io.File
5252
import java.io.InputStream
53+
import java.io.OutputStream
5354
import java.math.BigDecimal
5455
import java.math.BigInteger
5556
import java.net.URL
@@ -399,3 +400,44 @@ public fun DataFrame.Companion.readArrowFeather(
399400
} else {
400401
readArrowFeather(File(path), nullability)
401402
}
403+
404+
//// IPC saving block
405+
//
406+
///**
407+
// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to new or existing [file].
408+
// * If file exists, it can be recreated or expanded.
409+
// */
410+
//public fun AnyFrame.writeArrowIPC(file: File, append: Boolean = true) {
411+
//
412+
//}
413+
//
414+
///**
415+
// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to [ByteArray]
416+
// */
417+
//public fun AnyFrame.writeArrowIPCToByteArray() {
418+
//
419+
//}
420+
//
421+
//// Feather saving block
422+
//
423+
///**
424+
// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to new or existing [file].
425+
// * If file exists, it would be recreated.
426+
// */
427+
//public fun AnyFrame.writeArrowFeather(file: File) {
428+
//
429+
//}
430+
//
431+
///**
432+
// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to [ByteArray]
433+
// */
434+
//public fun DataFrame.Companion.writeArrowFeatherToByteArray(): ByteArray {
435+
//
436+
//}
437+
//
438+
///**
439+
// * Write [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) from existing [stream]
440+
// */
441+
//public fun DataFrame.Companion.writeArrowFeather(stream: OutputStream) {
442+
//
443+
//}

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ openapi = "2.1.2"
2323
junit = "4.13.2"
2424
kotestAsserions = "4.6.3"
2525
jsoup = "1.14.3"
26-
arrow = "8.0.0"
26+
arrow = "9.0.0"
2727

2828
[libraries]
2929
ksp-gradle = { group = "com.google.devtools.ksp", name = "symbol-processing-gradle-plugin", version.ref = "ksp" }

0 commit comments

Comments
 (0)