Skip to content

Commit fcfae09

Browse files
committed
Implement strictType
1 parent aee6def commit fcfae09

File tree

1 file changed

+54
-13
lines changed
  • dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io

1 file changed

+54
-13
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import org.apache.arrow.vector.util.Text
1818
import org.jetbrains.kotlinx.dataframe.AnyCol
1919
import org.jetbrains.kotlinx.dataframe.DataFrame
2020
import org.jetbrains.kotlinx.dataframe.api.*
21+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
2122
import java.io.ByteArrayOutputStream
2223
import java.nio.channels.Channels
2324
import java.nio.channels.WritableByteChannel
@@ -96,21 +97,47 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
9697
vector.valueCount = size
9798
}
9899

100+
/**
 * Converts [column] to the Kotlin value type that matches the Arrow [targetFieldType].
 *
 * A `null` column is passed through unchanged. Binary targets and every Arrow type without
 * a branch below end in [TODO], i.e. they throw [NotImplementedError].
 *
 * @param column the source column, or `null` when there is no data for the field
 * @param targetFieldType the Arrow type the column values should be converted for
 * @return the converted column, or `null` if [column] is `null`
 */
private fun convertColumnToTarget(column: AnyCol?, targetFieldType: ArrowType): AnyCol? {
    // Nothing to convert when there is no source column.
    val source = column ?: return null

    return when (targetFieldType) {
        // Both string flavours take the same string conversion.
        ArrowType.Utf8(), ArrowType.LargeUtf8() -> source.convertToString()
        ArrowType.Binary(), ArrowType.LargeBinary() -> TODO("Saving var binary is currently not implemented")
        ArrowType.Bool() -> source.convertToBoolean()

        // Signed integers, by bit width.
        ArrowType.Int(8, true) -> source.convertTo<Byte>()
        ArrowType.Int(16, true) -> source.convertTo<Short>()
        ArrowType.Int(32, true) -> source.convertTo<Int>()
        ArrowType.Int(64, true) -> source.convertTo<Long>()
        // ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) ->

        is ArrowType.Decimal -> source.convertToBigDecimal()
        ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) -> source.convertToFloat()
        ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) -> source.convertToDouble()

        // Date and time targets.
        ArrowType.Date(DateUnit.DAY) -> source.convertToLocalDate()
        ArrowType.Date(DateUnit.MILLISECOND) -> source.convertToLocalDateTime()
        is ArrowType.Time -> source.convertToLocalTime()
        // is ArrowType.Duration ->
        // is ArrowType.Struct ->

        else -> TODO("Saving ${targetFieldType.javaClass.canonicalName} is not implemented")
    }
}
125+
99126
private fun infillVector(vector: FieldVector, column: AnyCol) {
100127
when (vector) {
101-
is VarCharVector -> column.forEachIndexed { i, value -> value?.let { vector.set(i, Text(value.toString())); value} ?: vector.setNull(i) }
102-
is LargeVarCharVector -> column.forEachIndexed { i, value -> value?.let { vector.set(i, Text(value.toString())); value} ?: vector.setNull(i) }
128+
is VarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value} ?: vector.setNull(i) }
129+
is LargeVarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value} ?: vector.setNull(i) }
103130
// is VarBinaryVector -> vector.values(range).withType()
104131
// is LargeVarBinaryVector -> vector.values(range).withType()
105132
is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> value?.let { vector.set(i, value.compareTo(false)); value} ?: vector.setNull(i) }
106-
is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
107133
is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
134+
is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
135+
is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
136+
is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
108137
// is UInt1Vector -> vector.values(range).withType()
109138
// is UInt2Vector -> vector.values(range).withType()
110139
// is UInt4Vector -> vector.values(range).withType()
111140
// is UInt8Vector -> vector.values(range).withType()
112-
is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
113-
is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
114141
is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
115142
is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
116143
is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
@@ -125,7 +152,7 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
125152
is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000).toInt()); value} ?: vector.setNull(i) }
126153
// is StructVector -> vector.values(range).withType()
127154
else -> {
128-
TODO("not fully implemented, ${vector.javaClass.canonicalName}")
155+
TODO("Saving to ${vector.javaClass.canonicalName} is not implemented")
129156
}
130157
}
131158

@@ -135,22 +162,36 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
135162

136163
/**
 * Creates an Arrow vector for [field], allocates it for all rows of [dataFrame] and fills it
 * with the values of [column].
 *
 * The column is first converted to the type requested by [field]. If that conversion fails with
 * a [TypeConversionException], the outcome depends on [strictType]: when `true` the exception is
 * rethrown; when `false` the original column is written with the Arrow type derived from it,
 * keeping the name, nullability, dictionary and children of [field].
 *
 * If the (possibly adjusted) field is not nullable but the column is missing or contains nulls,
 * [strictNullable] decides whether to fail or to create a nullable vector instead.
 *
 * The created vector is closed before any failure during allocation or filling is propagated,
 * so that its buffers are not leaked.
 *
 * @param field target Arrow field description
 * @param column source column, or `null` when the data frame has no data for [field]
 * @param strictType whether a failed type conversion is an error
 * @param strictNullable whether nulls in a non-nullable field are an error
 * @return the allocated and filled vector; the caller owns it and is responsible for closing it
 * @throws TypeConversionException if conversion fails and [strictType] is `true`
 * @throws Exception if nulls are present for a non-nullable field and [strictNullable] is `true`
 */
private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
    val containNulls = (column == null || column.hasNulls())

    // Convert the column to the type specified in field. (If it already has the target type, convertTo does nothing.)
    val (convertedColumn, actualField) = try {
        convertColumnToTarget(column, field.type) to field
    } catch (e: TypeConversionException) {
        // With strictType enabled a failed conversion is an error.
        if (strictType) throw e

        // Otherwise keep the original data with its own type; the target nullability is preserved.
        val originalColumn = checkNotNull(column) {
            "Type conversion failed for ${field.name} although no source column was provided"
        }
        val actualType = listOf(originalColumn).toArrowSchema().fields.first().fieldType.type
        originalColumn to Field(
            field.name,
            FieldType(field.isNullable, actualType, field.fieldType.dictionary),
            field.children,
        )
    }

    // Decide which field definition the vector is created from.
    val vectorField = if (!actualField.isNullable && containNulls) {
        if (strictNullable) {
            throw Exception("${actualField.name} column contains nulls but should be not nullable")
        }
        // Non-strict mode: widen the field to nullable so that the nulls can be stored.
        Field(
            actualField.name,
            FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary),
            actualField.children,
        )
    } else {
        actualField
    }
    val vector: FieldVector = vectorField.createVector(allocator)

    val rowsCount = dataFrame.rowsCount()
    try {
        allocateVector(vector, rowsCount)
        if (convertedColumn == null) {
            check(actualField.isNullable) {
                "No data for ${actualField.name}, but the field is not nullable"
            }
            infillWithNulls(vector, rowsCount)
        } else {
            infillVector(vector, convertedColumn)
        }
    } catch (e: Throwable) {
        // Release the vector's buffers before propagating; nobody else holds a reference yet.
        vector.close()
        throw e
    }
    return vector
}

0 commit comments

Comments
 (0)