@@ -40,6 +40,8 @@ import org.apache.arrow.vector.util.Text
40
40
import org.jetbrains.kotlinx.dataframe.AnyCol
41
41
import org.jetbrains.kotlinx.dataframe.AnyFrame
42
42
import org.jetbrains.kotlinx.dataframe.DataFrame
43
+ import org.jetbrains.kotlinx.dataframe.name
44
+ import org.jetbrains.kotlinx.dataframe.typeClass
43
45
import org.jetbrains.kotlinx.dataframe.api.convertToBigDecimal
44
46
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
45
47
import org.jetbrains.kotlinx.dataframe.api.convertToByte
@@ -54,9 +56,8 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort
54
56
import org.jetbrains.kotlinx.dataframe.api.convertToString
55
57
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
56
58
import org.jetbrains.kotlinx.dataframe.api.map
57
- import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
59
+ import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
58
60
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
59
- import org.jetbrains.kotlinx.dataframe.typeClass
60
61
import org.slf4j.LoggerFactory
61
62
import java.io.ByteArrayOutputStream
62
63
import java.io.File
@@ -70,14 +71,14 @@ import java.time.LocalTime
70
71
import kotlin.reflect.full.isSubtypeOf
71
72
import kotlin.reflect.typeOf
72
73
73
- public val ignoreWarningMessage : (String ) -> Unit = { message: String -> }
74
- public val writeWarningMessage : (String ) -> Unit = { message: String -> System .err.println (message) }
74
+ public val ignoreMismatchMessage : (ConvertingMismatch ) -> Unit = { message: ConvertingMismatch -> }
75
+ public val writeMismatchMessage : (ConvertingMismatch ) -> Unit = { message: ConvertingMismatch -> System .err.println (message) }
75
76
76
77
private val logger = LoggerFactory .getLogger(ArrowWriter ::class .java)
77
78
78
- public val logWarningMessage : (String ) -> Unit = { message: String -> logger.debug(message) }
79
+ public val logMismatchMessage : (ConvertingMismatch ) -> Unit = { message: ConvertingMismatch -> logger.debug(message.toString() ) }
79
80
80
- public fun AnyCol.toArrowField (warningSubscriber : (String ) -> Unit = ignoreWarningMessage ): Field {
81
+ public fun AnyCol.toArrowField (mismatchSubscriber : (ConvertingMismatch ) -> Unit = ignoreMismatchMessage ): Field {
81
82
val column = this
82
83
val columnType = column.type()
83
84
val nullable = columnType.isMarkedNullable
@@ -149,7 +150,7 @@ public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarni
149
150
)
150
151
151
152
else -> {
152
- warningSubscriber( " Column ${ column.name()} has type ${ column.typeClass.java.canonicalName} , will be saved as String " )
153
+ mismatchSubscriber( ConvertingMismatch . SavedAsString ( column.name(), column.typeClass.java) )
153
154
Field (column.name(), FieldType (true , ArrowType .Utf8 (), null ), emptyList())
154
155
}
155
156
}
@@ -158,8 +159,8 @@ public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarni
158
159
* Create Arrow [Schema] matching [this] actual data.
159
160
* Columns with not supported types will be interpreted as String
160
161
*/
161
- public fun List<AnyCol>.toArrowSchema (warningSubscriber : (String ) -> Unit = ignoreWarningMessage ): Schema {
162
- val fields = this .map { it.toArrowField(warningSubscriber ) }
162
+ public fun List<AnyCol>.toArrowSchema (mismatchSubscriber : (ConvertingMismatch ) -> Unit = ignoreMismatchMessage ): Schema {
163
+ val fields = this .map { it.toArrowField(mismatchSubscriber ) }
163
164
return Schema (fields)
164
165
}
165
166
@@ -174,8 +175,8 @@ public fun DataFrame<*>.arrowWriter(): ArrowWriter = this.arrowWriter(this.colum
174
175
public fun DataFrame <* >.arrowWriter (
175
176
targetSchema : Schema ,
176
177
mode : ArrowWriter .Companion .Mode = ArrowWriter .Companion .Mode .STRICT ,
177
- warningSubscriber : (String ) -> Unit = ignoreWarningMessage
178
- ): ArrowWriter = ArrowWriter (this , targetSchema, mode, warningSubscriber )
178
+ mismatchSubscriber : (ConvertingMismatch ) -> Unit = ignoreMismatchMessage
179
+ ): ArrowWriter = ArrowWriter (this , targetSchema, mode, mismatchSubscriber )
179
180
180
181
/* *
181
182
* Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray, OutputStream or raw Channel) with [targetSchema].
@@ -185,7 +186,7 @@ public class ArrowWriter(
185
186
private val dataFrame : DataFrame <* >,
186
187
private val targetSchema : Schema ,
187
188
private val mode : Mode ,
188
- private val warningSubscriber : (String ) -> Unit = ignoreWarningMessage
189
+ private val mismatchSubscriber : (ConvertingMismatch ) -> Unit = ignoreMismatchMessage
189
190
) : AutoCloseable {
190
191
191
192
public companion object {
@@ -296,32 +297,50 @@ public class ArrowWriter(
296
297
val containNulls = (column == null || column.hasNulls())
297
298
// Convert the column to type specified in field. (If we already have target type, convertTo will do nothing)
298
299
299
- fun handleConversionFail (e : Exception ): Pair <AnyCol ?, Field > {
300
+ val (convertedColumn, actualField) = try {
301
+ convertColumnToTarget(column, field.type) to field
302
+ } catch (e: CellConversionException ) {
300
303
if (strictType) {
301
304
// If conversion failed but strictType is enabled, throw the exception
302
- throw e
305
+ val mismatch = ConvertingMismatch .TypeConversionFail .ConversionFailError (e.column, e.row, e)
306
+ mismatchSubscriber(mismatch)
307
+ throw ConvertingException (mismatch)
303
308
} else {
304
309
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
305
- warningSubscriber(e.message !! )
306
- val actualType = listOf (column!! ).toArrowSchema(warningSubscriber ).fields.first().fieldType.type
310
+ mismatchSubscriber( ConvertingMismatch . TypeConversionFail . ConversionFailIgnored (e.column, e.row, e) )
311
+ val actualType = listOf (column!! ).toArrowSchema(mismatchSubscriber ).fields.first().fieldType.type
307
312
val actualField = Field (field.name, FieldType (field.isNullable, actualType, field.fieldType.dictionary), field.children)
308
- return column to actualField
313
+ column to actualField
309
314
}
310
- }
311
-
312
- val (convertedColumn, actualField) = try {
313
- convertColumnToTarget(column, field.type) to field
314
- } catch (e: TypeConversionException ) {
315
- handleConversionFail(e)
316
315
} catch (e: TypeConverterNotFoundException ) {
317
- handleConversionFail(e)
316
+ if (strictType) {
317
+ // If conversion failed but strictType is enabled, throw the exception
318
+ val mismatch = ConvertingMismatch .TypeConversionNotFound .ConversionNotFoundError (field.name, e)
319
+ mismatchSubscriber(mismatch)
320
+ throw ConvertingException (mismatch)
321
+ } else {
322
+ // If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
323
+ mismatchSubscriber(ConvertingMismatch .TypeConversionNotFound .ConversionNotFoundIgnored (field.name, e))
324
+ val actualType = listOf (column!! ).toArrowSchema(mismatchSubscriber).fields.first().fieldType.type
325
+ val actualField = Field (field.name, FieldType (field.isNullable, actualType, field.fieldType.dictionary), field.children)
326
+ column to actualField
327
+ }
318
328
}
319
329
320
330
val vector = if (! actualField.isNullable && containNulls) {
331
+ var firstNullValue: Int? = null ;
332
+ for (i in 0 until (column?.size() ? : - 1 )) {
333
+ if (column!! [i] == null ) {
334
+ firstNullValue = i;
335
+ break ;
336
+ }
337
+ }
321
338
if (strictNullable) {
322
- throw IllegalArgumentException (" Column \" ${actualField.name} \" contains nulls but should be not nullable" )
339
+ val mismatch = ConvertingMismatch .NullableMismatch .NullValueError (actualField.name, firstNullValue)
340
+ mismatchSubscriber(mismatch)
341
+ throw ConvertingException (mismatch)
323
342
} else {
324
- warningSubscriber( " Column \" ${ actualField.name} \" contains nulls but expected not nullable " )
343
+ mismatchSubscriber( ConvertingMismatch . NullableMismatch . NullValueIgnored ( actualField.name, firstNullValue) )
325
344
Field (actualField.name, FieldType (true , actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
326
345
}
327
346
} else {
@@ -339,7 +358,7 @@ public class ArrowWriter(
339
358
}
340
359
341
360
private fun List<AnyCol>.toVectors (): List <FieldVector > = this .map {
342
- val field = it.toArrowField(warningSubscriber )
361
+ val field = it.toArrowField(mismatchSubscriber )
343
362
allocateVectorAndInfill(field, it, true , true )
344
363
}
345
364
/* *
@@ -352,9 +371,11 @@ public class ArrowWriter(
352
371
val column = dataFrame.getColumnOrNull(field.name)
353
372
if (column == null && ! field.isNullable) {
354
373
if (mode.restrictNarrowing) {
355
- throw IllegalArgumentException (" Column \" ${field.name} \" is not presented" )
374
+ val mismatch = ConvertingMismatch .NarrowingMismatch .NotPresentedColumnError (field.name)
375
+ mismatchSubscriber(mismatch)
376
+ throw ConvertingException (mismatch)
356
377
} else {
357
- warningSubscriber( " Column \" ${ field.name} \" is not presented " )
378
+ mismatchSubscriber( ConvertingMismatch . NarrowingMismatch . NotPresentedColumnIgnored ( field.name) )
358
379
continue
359
380
}
360
381
}
@@ -371,9 +392,12 @@ public class ArrowWriter(
371
392
val otherColumns = dataFrame.columns().filter { column -> ! mainVectors.containsKey(column.name()) }
372
393
if (! mode.restrictWidening) {
373
394
vectors.addAll(otherColumns.toVectors())
395
+ otherColumns.forEach {
396
+ mismatchSubscriber(ConvertingMismatch .WideningMismatch .AddedColumn (it.name))
397
+ }
374
398
} else {
375
399
otherColumns.forEach {
376
- warningSubscriber( " Column \" ${ it.name()} \" is not described in target schema and was ignored " )
400
+ mismatchSubscriber( ConvertingMismatch . WideningMismatch . RejectedColumn ( it.name) )
377
401
}
378
402
}
379
403
return VectorSchemaRoot (vectors)
0 commit comments