@@ -18,6 +18,7 @@ import org.apache.arrow.vector.util.Text
18
18
import org.jetbrains.kotlinx.dataframe.AnyCol
19
19
import org.jetbrains.kotlinx.dataframe.DataFrame
20
20
import org.jetbrains.kotlinx.dataframe.api.*
21
+ import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
21
22
import java.io.ByteArrayOutputStream
22
23
import java.nio.channels.Channels
23
24
import java.nio.channels.WritableByteChannel
@@ -96,21 +97,47 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
96
97
vector.valueCount = size
97
98
}
98
99
100
/**
 * Converts [column] to the Kotlin value type that matches the Arrow [targetFieldType].
 *
 * @param column the source column; `null` is passed through unchanged.
 * @param targetFieldType the Arrow type the column is going to be saved as.
 * @return the converted column, or `null` when [column] is `null`.
 * @throws NotImplementedError (via [TODO]) for binary types and for every Arrow type
 * that has no branch below.
 */
private fun convertColumnToTarget(column: AnyCol?, targetFieldType: ArrowType): AnyCol? {
    // Nothing to convert when there is no column at all.
    val source = column ?: return null

    return when (targetFieldType) {
        // Textual types share the same string conversion.
        ArrowType.Utf8(), ArrowType.LargeUtf8() -> source.convertToString()

        ArrowType.Binary(), ArrowType.LargeBinary() ->
            TODO(" Saving var binary is currently not implemented")

        ArrowType.Bool() -> source.convertToBoolean()

        // Signed integers only; the unsigned variants
        // (ArrowType.Int(8, false), Int(16, false), Int(32, false), Int(64, false))
        // have no branch yet and end up in `else`.
        ArrowType.Int(8, true) -> source.convertTo<Byte>()
        ArrowType.Int(16, true) -> source.convertTo<Short>()
        ArrowType.Int(32, true) -> source.convertTo<Int>()
        ArrowType.Int(64, true) -> source.convertTo<Long>()

        is ArrowType.Decimal -> source.convertToBigDecimal()

        ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) -> source.convertToFloat()
        ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) -> source.convertToDouble()

        ArrowType.Date(DateUnit.DAY) -> source.convertToLocalDate()
        ArrowType.Date(DateUnit.MILLISECOND) -> source.convertToLocalDateTime()

        is ArrowType.Time -> source.convertToLocalTime()

        // ArrowType.Duration and ArrowType.Struct are not handled yet and also end up here.
        else -> TODO(" Saving ${targetFieldType.javaClass.canonicalName} is not implemented")
    }
}
125
+
99
126
private fun infillVector (vector : FieldVector , column : AnyCol ) {
100
127
when (vector) {
101
- is VarCharVector -> column.forEachIndexed { i, value -> value?.let { vector.set(i, Text (value.toString() )); value} ? : vector.setNull(i) }
102
- is LargeVarCharVector -> column.forEachIndexed { i, value -> value?.let { vector.set(i, Text (value.toString() )); value} ? : vector.setNull(i) }
128
+ is VarCharVector -> column.convertToString(). forEachIndexed { i, value -> value?.let { vector.set(i, Text (value)); value} ? : vector.setNull(i) }
129
+ is LargeVarCharVector -> column.convertToString(). forEachIndexed { i, value -> value?.let { vector.set(i, Text (value)); value} ? : vector.setNull(i) }
103
130
// is VarBinaryVector -> vector.values(range).withType()
104
131
// is LargeVarBinaryVector -> vector.values(range).withType()
105
132
is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> value?.let { vector.set(i, value.compareTo(false )); value} ? : vector.setNull(i) }
106
- is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
107
133
is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
134
+ is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
135
+ is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
136
+ is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
108
137
// is UInt1Vector -> vector.values(range).withType()
109
138
// is UInt2Vector -> vector.values(range).withType()
110
139
// is UInt4Vector -> vector.values(range).withType()
111
140
// is UInt8Vector -> vector.values(range).withType()
112
- is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
113
- is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
114
141
is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
115
142
is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
116
143
is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
@@ -125,7 +152,7 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
125
152
is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000 ).toInt()); value} ? : vector.setNull(i) }
126
153
// is StructVector -> vector.values(range).withType()
127
154
else -> {
128
- TODO (" not fully implemented, ${vector.javaClass.canonicalName} " )
155
+ TODO (" Saving to ${vector.javaClass.canonicalName} is not implemented " )
129
156
}
130
157
}
131
158
@@ -135,22 +162,36 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
135
162
136
163
/**
 * Creates an Arrow vector for [field], allocates it for every row of [dataFrame]
 * and fills it with the values of [column].
 *
 * The column is first converted with [convertColumnToTarget] to the type requested by
 * `field.type`. A `null` column is written as a vector consisting only of nulls.
 *
 * @param field the target Arrow field.
 * @param column the source column, or `null` when there is no data for [field].
 * @param strictType when `true`, a [TypeConversionException] raised by the conversion is
 * rethrown; when `false`, the column is saved unconverted using the Arrow type derived
 * from the column itself (the nullability of [field] is kept).
 * @param strictNullable when `true`, nulls destined for a non-nullable field raise an
 * exception; when `false`, the vector is created from a nullable copy of the field.
 * @return the allocated and filled vector.
 * @throws TypeConversionException if the conversion fails and [strictType] is `true`.
 * @throws Exception if the data contains nulls, the field is not nullable and
 * [strictNullable] is `true`.
 */
private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
    val containNulls = (column == null || column.hasNulls())

    // Convert the column to the type specified in the field.
    val (convertedColumn, actualField) = try {
        convertColumnToTarget(column, field.type) to field
    } catch (e: TypeConversionException) {
        // With strictType enabled a failed conversion is an error.
        if (strictType) throw e

        // Otherwise keep the original data and describe it with its own Arrow type.
        // convertColumnToTarget returns early for a null column, so a conversion
        // failure can only happen for an existing column.
        val sourceColumn = checkNotNull(column) { "Conversion failed for a missing column ${field.name}" }
        val actualType = listOf(sourceColumn).toArrowSchema().fields.first().fieldType.type
        sourceColumn to Field(
            field.name,
            FieldType(field.isNullable, actualType, field.fieldType.dictionary),
            field.children,
        )
    }

    // The field the vector is really built from: widened to nullable when the data
    // contains nulls and strict nullability is not requested.
    val vectorField = if (!actualField.isNullable && containNulls) {
        if (strictNullable) {
            throw Exception(" ${actualField.name} column contains nulls but should be not nullable")
        }
        Field(
            actualField.name,
            FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary),
            actualField.children,
        )
    } else {
        actualField
    }

    val vector = vectorField.createVector(allocator)!!
    val rowsCount = dataFrame.rowsCount()
    try {
        allocateVector(vector, rowsCount)
        if (convertedColumn == null) {
            // Validate the field the vector was created from, not the pre-widening one:
            // otherwise a missing column for a non-nullable field would fail here even
            // though the vector has just been made nullable on purpose.
            check(vectorField.isNullable) { "Vector for ${vectorField.name} must be nullable to be filled with nulls" }
            infillWithNulls(vector, rowsCount)
        } else {
            infillVector(vector, convertedColumn)
        }
    } catch (e: Throwable) {
        // Do not leak the buffers of a vector that will never reach the caller.
        vector.close()
        throw e
    }
    return vector
}
0 commit comments