Skip to content

Commit b760974

Browse files
committed
Arrow API Refactoring, extract implementation
1 parent 6f38c57 commit b760974

File tree

7 files changed

+684
-633
lines changed

7 files changed

+684
-633
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 16 additions & 436 deletions
Large diffs are not rendered by default.
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package org.jetbrains.kotlinx.dataframe.io
2+
3+
import kotlinx.datetime.TimeZone
4+
import kotlinx.datetime.toInstant
5+
import kotlinx.datetime.toJavaLocalDate
6+
import org.apache.arrow.memory.RootAllocator
7+
import org.apache.arrow.vector.BaseFixedWidthVector
8+
import org.apache.arrow.vector.BaseVariableWidthVector
9+
import org.apache.arrow.vector.BigIntVector
10+
import org.apache.arrow.vector.BitVector
11+
import org.apache.arrow.vector.DateDayVector
12+
import org.apache.arrow.vector.DateMilliVector
13+
import org.apache.arrow.vector.Decimal256Vector
14+
import org.apache.arrow.vector.DecimalVector
15+
import org.apache.arrow.vector.FieldVector
16+
import org.apache.arrow.vector.FixedWidthVector
17+
import org.apache.arrow.vector.Float4Vector
18+
import org.apache.arrow.vector.Float8Vector
19+
import org.apache.arrow.vector.IntVector
20+
import org.apache.arrow.vector.LargeVarCharVector
21+
import org.apache.arrow.vector.SmallIntVector
22+
import org.apache.arrow.vector.TimeMicroVector
23+
import org.apache.arrow.vector.TimeMilliVector
24+
import org.apache.arrow.vector.TimeNanoVector
25+
import org.apache.arrow.vector.TimeSecVector
26+
import org.apache.arrow.vector.TinyIntVector
27+
import org.apache.arrow.vector.VarCharVector
28+
import org.apache.arrow.vector.VariableWidthVector
29+
import org.apache.arrow.vector.VectorSchemaRoot
30+
import org.apache.arrow.vector.types.DateUnit
31+
import org.apache.arrow.vector.types.FloatingPointPrecision
32+
import org.apache.arrow.vector.types.pojo.ArrowType
33+
import org.apache.arrow.vector.types.pojo.Field
34+
import org.apache.arrow.vector.types.pojo.FieldType
35+
import org.apache.arrow.vector.types.pojo.Schema
36+
import org.apache.arrow.vector.util.Text
37+
import org.jetbrains.kotlinx.dataframe.AnyCol
38+
import org.jetbrains.kotlinx.dataframe.DataFrame
39+
import org.jetbrains.kotlinx.dataframe.name
40+
import org.jetbrains.kotlinx.dataframe.api.convertToBigDecimal
41+
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
42+
import org.jetbrains.kotlinx.dataframe.api.convertToByte
43+
import org.jetbrains.kotlinx.dataframe.api.convertToDouble
44+
import org.jetbrains.kotlinx.dataframe.api.convertToFloat
45+
import org.jetbrains.kotlinx.dataframe.api.convertToInt
46+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalDate
47+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalDateTime
48+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalTime
49+
import org.jetbrains.kotlinx.dataframe.api.convertToLong
50+
import org.jetbrains.kotlinx.dataframe.api.convertToShort
51+
import org.jetbrains.kotlinx.dataframe.api.convertToString
52+
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
53+
import org.jetbrains.kotlinx.dataframe.api.map
54+
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
55+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
56+
57+
/**
58+
* Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray, OutputStream or raw Channel) with [targetSchema].
59+
* If [dataFrame] content does not match with [targetSchema], behaviour is specified by [mode], mismatches would be sent to [mismatchSubscriber]
60+
*/
61+
internal class ArrowWriterImpl(
62+
override val dataFrame: DataFrame<*>,
63+
override val targetSchema: Schema,
64+
override val mode: ArrowWriter.Companion.Mode,
65+
override val mismatchSubscriber: (ConvertingMismatch) -> Unit = ignoreMismatchMessage
66+
): ArrowWriter {
67+
68+
private val allocator = RootAllocator()
69+
70+
private fun allocateVector(vector: FieldVector, size: Int) {
71+
when (vector) {
72+
is FixedWidthVector -> vector.allocateNew(size)
73+
is VariableWidthVector -> vector.allocateNew(size)
74+
else -> TODO("Not implemented for ${vector.javaClass.canonicalName}")
75+
}
76+
}
77+
78+
private fun infillWithNulls(vector: FieldVector, size: Int) {
79+
when (vector) {
80+
is BaseFixedWidthVector -> for (i in 0 until size) { vector.setNull(i) }
81+
is BaseVariableWidthVector -> for (i in 0 until size) { vector.setNull(i) }
82+
else -> TODO("Not implemented for ${vector.javaClass.canonicalName}")
83+
}
84+
vector.valueCount = size
85+
}
86+
87+
private fun convertColumnToTarget(column: AnyCol?, targetFieldType: ArrowType): AnyCol? {
88+
if (column == null) return null
89+
return when (targetFieldType) {
90+
ArrowType.Utf8() -> column.map { it?.toString() }
91+
ArrowType.LargeUtf8() -> column.map { it?.toString() }
92+
ArrowType.Binary(), ArrowType.LargeBinary() -> TODO("Saving var binary is currently not implemented")
93+
ArrowType.Bool() -> column.convertToBoolean()
94+
ArrowType.Int(8, true) -> column.convertToByte()
95+
ArrowType.Int(16, true) -> column.convertToShort()
96+
ArrowType.Int(32, true) -> column.convertToInt()
97+
ArrowType.Int(64, true) -> column.convertToLong()
98+
// ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) -> todo
99+
is ArrowType.Decimal -> column.convertToBigDecimal()
100+
ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) -> column.convertToDouble().convertToFloat() //Use [convertToDouble] as locale logic step
101+
ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) -> column.convertToDouble()
102+
ArrowType.Date(DateUnit.DAY) -> column.convertToLocalDate()
103+
ArrowType.Date(DateUnit.MILLISECOND) -> column.convertToLocalDateTime()
104+
is ArrowType.Time -> column.convertToLocalTime()
105+
// is ArrowType.Duration -> todo
106+
// is ArrowType.Struct -> todo
107+
else -> {
108+
TODO("Saving ${targetFieldType.javaClass.canonicalName} is not implemented")
109+
}
110+
}
111+
}
112+
113+
private fun infillVector(vector: FieldVector, column: AnyCol) {
114+
when (vector) {
115+
is VarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value } ?: vector.setNull(i) }
116+
is LargeVarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value } ?: vector.setNull(i) }
117+
// is VarBinaryVector -> todo
118+
// is LargeVarBinaryVector -> todo
119+
is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> value?.let { vector.set(i, value.compareTo(false)); value } ?: vector.setNull(i) }
120+
is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
121+
is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
122+
is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
123+
is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
124+
// is UInt1Vector -> todo
125+
// is UInt2Vector -> todo
126+
// is UInt4Vector -> todo
127+
// is UInt8Vector -> todo
128+
is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
129+
is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
130+
is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
131+
is Float4Vector -> column.convertToFloat().forEachIndexed { i, value -> value?.let { vector.set(i, value); value } ?: vector.setNull(i) }
132+
133+
is DateDayVector -> column.convertToLocalDate().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toJavaLocalDate().toEpochDay()).toInt()); value } ?: vector.setNull(i) }
134+
is DateMilliVector -> column.convertToLocalDateTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toInstant(
135+
TimeZone.UTC).toEpochMilliseconds()); value } ?: vector.setNull(i) }
136+
// is DurationVector -> todo
137+
is TimeNanoVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay()); value } ?: vector.setNull(i) }
138+
is TimeMicroVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay() / 1000); value } ?: vector.setNull(i) }
139+
is TimeMilliVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000).toInt()); value } ?: vector.setNull(i) }
140+
is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000).toInt()); value } ?: vector.setNull(i) }
141+
// is StructVector -> todo
142+
else -> {
143+
TODO("Saving to ${vector.javaClass.canonicalName} is not implemented")
144+
}
145+
}
146+
147+
vector.valueCount = dataFrame.rowsCount()
148+
}
149+
150+
/**
151+
* Create Arrow FieldVector with [column] content cast to [field] type according to [strictType] and [strictNullable] settings.
152+
*/
153+
private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
154+
val containNulls = (column == null || column.hasNulls())
155+
// Convert the column to type specified in field. (If we already have target type, convertTo will do nothing)
156+
157+
val (convertedColumn, actualField) = try {
158+
convertColumnToTarget(column, field.type) to field
159+
} catch (e: CellConversionException) {
160+
if (strictType) {
161+
// If conversion failed but strictType is enabled, throw the exception
162+
val mismatch = ConvertingMismatch.TypeConversionFail.ConversionFailError(e.column, e.row, e)
163+
mismatchSubscriber(mismatch)
164+
throw ConvertingException(mismatch)
165+
} else {
166+
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
167+
mismatchSubscriber(ConvertingMismatch.TypeConversionFail.ConversionFailIgnored(e.column, e.row, e))
168+
column to column!!.toArrowField(mismatchSubscriber)
169+
}
170+
} catch (e: TypeConverterNotFoundException) {
171+
if (strictType) {
172+
// If conversion failed but strictType is enabled, throw the exception
173+
val mismatch = ConvertingMismatch.TypeConversionNotFound.ConversionNotFoundError(field.name, e)
174+
mismatchSubscriber(mismatch)
175+
throw ConvertingException(mismatch)
176+
} else {
177+
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
178+
mismatchSubscriber(ConvertingMismatch.TypeConversionNotFound.ConversionNotFoundIgnored(field.name, e))
179+
column to column!!.toArrowField(mismatchSubscriber)
180+
}
181+
}
182+
183+
val vector = if (!actualField.isNullable && containNulls) {
184+
var firstNullValue: Int? = null;
185+
for (i in 0 until (column?.size() ?: -1)) {
186+
if (column!![i] == null) {
187+
firstNullValue = i;
188+
break;
189+
}
190+
}
191+
if (strictNullable) {
192+
val mismatch = ConvertingMismatch.NullableMismatch.NullValueError(actualField.name, firstNullValue)
193+
mismatchSubscriber(mismatch)
194+
throw ConvertingException(mismatch)
195+
} else {
196+
mismatchSubscriber(ConvertingMismatch.NullableMismatch.NullValueIgnored(actualField.name, firstNullValue))
197+
Field(actualField.name, FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
198+
}
199+
} else {
200+
actualField.createVector(allocator)!!
201+
}
202+
203+
allocateVector(vector, dataFrame.rowsCount())
204+
if (convertedColumn == null) {
205+
check(actualField.isNullable)
206+
infillWithNulls(vector, dataFrame.rowsCount())
207+
} else {
208+
infillVector(vector, convertedColumn)
209+
}
210+
return vector
211+
}
212+
213+
private fun List<AnyCol>.toVectors(): List<FieldVector> = this.map {
214+
val field = it.toArrowField(mismatchSubscriber)
215+
allocateVectorAndInfill(field, it, true, true)
216+
}
217+
218+
override fun allocateVectorSchemaRoot(): VectorSchemaRoot {
219+
val mainVectors = LinkedHashMap<String, FieldVector>()
220+
try {
221+
for (field in targetSchema.fields) {
222+
val column = dataFrame.getColumnOrNull(field.name)
223+
if (column == null && !field.isNullable) {
224+
if (mode.restrictNarrowing) {
225+
val mismatch = ConvertingMismatch.NarrowingMismatch.NotPresentedColumnError(field.name)
226+
mismatchSubscriber(mismatch)
227+
throw ConvertingException(mismatch)
228+
} else {
229+
mismatchSubscriber(ConvertingMismatch.NarrowingMismatch.NotPresentedColumnIgnored(field.name))
230+
continue
231+
}
232+
}
233+
234+
val vector = allocateVectorAndInfill(field, column, mode.strictType, mode.strictNullable)
235+
mainVectors[field.name] = vector
236+
}
237+
} catch (e: Exception) {
238+
mainVectors.values.forEach { it.close() } // Clear buffers before throwing exception
239+
throw e
240+
}
241+
val vectors = ArrayList<FieldVector>()
242+
vectors.addAll(mainVectors.values)
243+
val otherColumns = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }
244+
if (!mode.restrictWidening) {
245+
vectors.addAll(otherColumns.toVectors())
246+
otherColumns.forEach {
247+
mismatchSubscriber(ConvertingMismatch.WideningMismatch.AddedColumn(it.name))
248+
}
249+
} else {
250+
otherColumns.forEach {
251+
mismatchSubscriber(ConvertingMismatch.WideningMismatch.RejectedColumn(it.name))
252+
}
253+
}
254+
return VectorSchemaRoot(vectors)
255+
}
256+
257+
override fun close() {
258+
allocator.close()
259+
}
260+
}

0 commit comments

Comments
 (0)