@@ -4,7 +4,29 @@ import kotlinx.datetime.TimeZone
4
4
import kotlinx.datetime.toInstant
5
5
import kotlinx.datetime.toJavaLocalDate
6
6
import org.apache.arrow.memory.RootAllocator
7
- import org.apache.arrow.vector.*
7
+ import org.apache.arrow.vector.BaseFixedWidthVector
8
+ import org.apache.arrow.vector.BaseVariableWidthVector
9
+ import org.apache.arrow.vector.FieldVector
10
+ import org.apache.arrow.vector.FixedWidthVector
11
+ import org.apache.arrow.vector.LargeVarCharVector
12
+ import org.apache.arrow.vector.TinyIntVector
13
+ import org.apache.arrow.vector.SmallIntVector
14
+ import org.apache.arrow.vector.IntVector
15
+ import org.apache.arrow.vector.BigIntVector
16
+ import org.apache.arrow.vector.BitVector
17
+ import org.apache.arrow.vector.DateDayVector
18
+ import org.apache.arrow.vector.DateMilliVector
19
+ import org.apache.arrow.vector.DecimalVector
20
+ import org.apache.arrow.vector.Decimal256Vector
21
+ import org.apache.arrow.vector.Float4Vector
22
+ import org.apache.arrow.vector.Float8Vector
23
+ import org.apache.arrow.vector.TimeMicroVector
24
+ import org.apache.arrow.vector.TimeMilliVector
25
+ import org.apache.arrow.vector.TimeNanoVector
26
+ import org.apache.arrow.vector.TimeSecVector
27
+ import org.apache.arrow.vector.VariableWidthVector
28
+ import org.apache.arrow.vector.VarCharVector
29
+ import org.apache.arrow.vector.VectorSchemaRoot
8
30
import org.apache.arrow.vector.ipc.ArrowFileWriter
9
31
import org.apache.arrow.vector.ipc.ArrowStreamWriter
10
32
import org.apache.arrow.vector.types.DateUnit
@@ -18,8 +40,22 @@ import org.apache.arrow.vector.util.Text
18
40
import org.jetbrains.kotlinx.dataframe.AnyCol
19
41
import org.jetbrains.kotlinx.dataframe.AnyFrame
20
42
import org.jetbrains.kotlinx.dataframe.DataFrame
21
- import org.jetbrains.kotlinx.dataframe.api.*
43
+ import org.jetbrains.kotlinx.dataframe.api.convertTo
44
+ import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
45
+ import org.jetbrains.kotlinx.dataframe.api.convertToBigDecimal
46
+ import org.jetbrains.kotlinx.dataframe.api.convertToDouble
47
+ import org.jetbrains.kotlinx.dataframe.api.convertToFloat
48
+ import org.jetbrains.kotlinx.dataframe.api.convertToLong
49
+ import org.jetbrains.kotlinx.dataframe.api.convertToInt
50
+ import org.jetbrains.kotlinx.dataframe.api.convertToLocalDate
51
+ import org.jetbrains.kotlinx.dataframe.api.convertToLocalTime
52
+ import org.jetbrains.kotlinx.dataframe.api.convertToLocalDateTime
53
+ import org.jetbrains.kotlinx.dataframe.api.convertToString
54
+ import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
22
55
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
56
+ import org.jetbrains.kotlinx.dataframe.typeClass
57
+ import org.slf4j.Logger
58
+ import org.slf4j.LoggerFactory
23
59
import java.io.ByteArrayOutputStream
24
60
import java.io.File
25
61
import java.io.FileOutputStream
@@ -31,11 +67,13 @@ import java.time.LocalDateTime
31
67
import java.time.LocalTime
32
68
import kotlin.reflect.typeOf
33
69
70
+ private val writeWarningMessage: (String ) -> Unit = {message: String -> System .err.println (message)}
71
+
34
72
/* *
35
73
* Create Arrow [Schema] matching [this] actual data.
36
74
* Columns with not supported types will be interpreted as String
37
75
*/
38
- public fun List<AnyCol>.toArrowSchema (): Schema {
76
+ public fun List<AnyCol>.toArrowSchema (warningSubscriber : ( String ) -> Unit = writeWarningMessage ): Schema {
39
77
val fields = this .map { column ->
40
78
when (column.type()) {
41
79
typeOf<String ?>() -> Field (column.name(), FieldType (true , ArrowType .Utf8 (), null ), emptyList())
@@ -71,7 +109,10 @@ public fun List<AnyCol>.toArrowSchema(): Schema {
71
109
typeOf<LocalTime ?>() -> Field (column.name(), FieldType (true , ArrowType .Time (TimeUnit .NANOSECOND , 64 ), null ), emptyList())
72
110
typeOf<LocalTime >() -> Field (column.name(), FieldType (false , ArrowType .Time (TimeUnit .NANOSECOND , 64 ), null ), emptyList())
73
111
74
- else -> Field (column.name(), FieldType (true , ArrowType .Utf8 (), null ), emptyList())
112
+ else -> {
113
+ warningSubscriber(" Column ${column.name()} has type ${column.typeClass.java.canonicalName} , will be saved as String" )
114
+ Field (column.name(), FieldType (true , ArrowType .Utf8 (), null ), emptyList())
115
+ }
75
116
}
76
117
}
77
118
return Schema (fields)
@@ -85,13 +126,22 @@ public fun DataFrame<*>.arrowWriter(): ArrowWriter = this.arrowWriter(this.colum
85
126
/* *
86
127
* Create [ArrowWriter] for [this] DataFrame with explicit [targetSchema]
87
128
*/
88
- public fun DataFrame <* >.arrowWriter (targetSchema : Schema , mode : ArrowWriter .Companion .Mode = ArrowWriter .Companion .Mode .STRICT ): ArrowWriter = ArrowWriter (this , targetSchema, mode)
129
+ public fun DataFrame <* >.arrowWriter (
130
+ targetSchema : Schema ,
131
+ mode : ArrowWriter .Companion .Mode = ArrowWriter .Companion .Mode .STRICT ,
132
+ warningSubscriber : (String ) -> Unit = writeWarningMessage
133
+ ): ArrowWriter = ArrowWriter (this , targetSchema, mode, warningSubscriber)
89
134
90
135
/* *
91
136
* Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray, OutputStream or raw Channel) with [targetSchema].
92
137
* If [dataFrame] content does not match with [targetSchema], behaviour is specified by [mode]
93
138
*/
94
- public class ArrowWriter (private val dataFrame : DataFrame <* >, private val targetSchema : Schema , private val mode : Mode ): AutoCloseable {
139
+ public class ArrowWriter (
140
+ private val dataFrame : DataFrame <* >,
141
+ private val targetSchema : Schema ,
142
+ private val mode : Mode ,
143
+ private val warningSubscriber : (String ) -> Unit = writeWarningMessage
144
+ ): AutoCloseable {
95
145
96
146
public companion object {
97
147
/* *
@@ -143,15 +193,15 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
143
193
ArrowType .Int (16 , true ) -> column.convertTo<Short >()
144
194
ArrowType .Int (32 , true ) -> column.convertTo<Int >()
145
195
ArrowType .Int (64 , true ) -> column.convertTo<Long >()
146
- // ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) ->
196
+ // ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) -> todo
147
197
is ArrowType .Decimal -> column.convertToBigDecimal()
148
198
ArrowType .FloatingPoint (FloatingPointPrecision .SINGLE ) -> column.convertToFloat()
149
199
ArrowType .FloatingPoint (FloatingPointPrecision .DOUBLE ) -> column.convertToDouble()
150
200
ArrowType .Date (DateUnit .DAY ) -> column.convertToLocalDate()
151
201
ArrowType .Date (DateUnit .MILLISECOND ) -> column.convertToLocalDateTime()
152
202
is ArrowType .Time -> column.convertToLocalTime()
153
- // is ArrowType.Duration ->
154
- // is ArrowType.Struct ->
203
+ // is ArrowType.Duration -> todo
204
+ // is ArrowType.Struct -> todo
155
205
else -> {
156
206
TODO (" Saving ${targetFieldType.javaClass.canonicalName} is not implemented" )
157
207
}
@@ -162,30 +212,30 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
162
212
when (vector) {
163
213
is VarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text (value)); value} ? : vector.setNull(i) }
164
214
is LargeVarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text (value)); value} ? : vector.setNull(i) }
165
- // is VarBinaryVector -> vector.values(range).withType()
166
- // is LargeVarBinaryVector -> vector.values(range).withType()
215
+ // is VarBinaryVector -> todo
216
+ // is LargeVarBinaryVector -> todo
167
217
is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> value?.let { vector.set(i, value.compareTo(false )); value} ? : vector.setNull(i) }
168
218
is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
169
219
is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
170
220
is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
171
221
is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
172
- // is UInt1Vector -> vector.values(range).withType()
173
- // is UInt2Vector -> vector.values(range).withType()
174
- // is UInt4Vector -> vector.values(range).withType()
175
- // is UInt8Vector -> vector.values(range).withType()
222
+ // is UInt1Vector -> todo
223
+ // is UInt2Vector -> todo
224
+ // is UInt4Vector -> todo
225
+ // is UInt8Vector -> todo
176
226
is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
177
227
is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
178
228
is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
179
229
is Float4Vector -> column.convertToFloat().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ? : vector.setNull(i) }
180
230
181
231
is DateDayVector -> column.convertToLocalDate().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toJavaLocalDate().toEpochDay()).toInt()); value} ? : vector.setNull(i) }
182
232
is DateMilliVector -> column.convertToLocalDateTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toInstant(TimeZone .UTC ).toEpochMilliseconds()); value} ? : vector.setNull(i) }
183
- // is DurationVector -> vector.values(range).withType()
233
+ // is DurationVector -> todo
184
234
is TimeNanoVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay()); value} ? : vector.setNull(i) }
185
235
is TimeMicroVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay() / 1000 ); value} ? : vector.setNull(i) }
186
236
is TimeMilliVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 ).toInt()); value} ? : vector.setNull(i) }
187
237
is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000 ).toInt()); value} ? : vector.setNull(i) }
188
- // is StructVector -> vector.values(range).withType()
238
+ // is StructVector -> todo
189
239
else -> {
190
240
TODO (" Saving to ${vector.javaClass.canonicalName} is not implemented" )
191
241
}
@@ -206,7 +256,8 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
206
256
throw e
207
257
} else {
208
258
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
209
- val actualType = listOf (column!! ).toArrowSchema().fields.first().fieldType.type
259
+ warningSubscriber(e.message)
260
+ val actualType = listOf (column!! ).toArrowSchema(warningSubscriber).fields.first().fieldType.type
210
261
val actualField = Field (field.name, FieldType (field.isNullable, actualType, field.fieldType.dictionary), field.children)
211
262
column to actualField
212
263
}
@@ -215,6 +266,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
215
266
if (strictNullable) {
216
267
throw Exception (" ${actualField.name} column contains nulls but should be not nullable" )
217
268
} else {
269
+ warningSubscriber(" ${actualField.name} column contains nulls but expected not nullable" )
218
270
Field (actualField.name, FieldType (true , actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
219
271
}
220
272
} else {
@@ -231,7 +283,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
231
283
return vector
232
284
}
233
285
234
- private fun List<AnyCol>.toVectors (): List <FieldVector > = this .toArrowSchema().fields.mapIndexed { i, field ->
286
+ private fun List<AnyCol>.toVectors (): List <FieldVector > = this .toArrowSchema(warningSubscriber ).fields.mapIndexed { i, field ->
235
287
allocateVectorAndInfill(field, this [i], true , true )
236
288
}
237
289
@@ -246,6 +298,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
246
298
if (mode.restrictNarrowing) {
247
299
throw Exception (" ${field.name} column is not presented" )
248
300
} else {
301
+ warningSubscriber(" ${field.name} column is not presented" )
249
302
continue
250
303
}
251
304
}
@@ -255,9 +308,11 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
255
308
}
256
309
val vectors = ArrayList <FieldVector >()
257
310
vectors.addAll(mainVectors.values)
311
+ val otherVectors = dataFrame.columns().filter { column -> ! mainVectors.containsKey(column.name()) }.toVectors()
258
312
if (! mode.restrictWidening) {
259
- val otherVectors = dataFrame.columns().filter { column -> ! mainVectors.containsKey(column.name()) }.toVectors()
260
313
vectors.addAll(otherVectors)
314
+ } else {
315
+ otherVectors.forEach { warningSubscriber(" ${it.name} column is not described in target schema and was ignored" ) }
261
316
}
262
317
return VectorSchemaRoot (vectors)
263
318
}
0 commit comments