Skip to content

Commit 6f38c57

Browse files
committed
Sending detailed ConvertingMismatch on saving to Arrow
1 parent 9e2a1ab commit 6f38c57

File tree

6 files changed

+136
-45
lines changed

6 files changed

+136
-45
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/exceptions/TypeConverterNotFoundException.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.jetbrains.kotlinx.dataframe.exceptions
22

3-
import kotlin.reflect.*
3+
import kotlin.reflect.KType
44

55
public class TypeConverterNotFoundException(public val from: KType, public val to: KType) : IllegalArgumentException() {
66

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ import org.apache.arrow.vector.util.Text
4040
import org.jetbrains.kotlinx.dataframe.AnyCol
4141
import org.jetbrains.kotlinx.dataframe.AnyFrame
4242
import org.jetbrains.kotlinx.dataframe.DataFrame
43+
import org.jetbrains.kotlinx.dataframe.name
44+
import org.jetbrains.kotlinx.dataframe.typeClass
4345
import org.jetbrains.kotlinx.dataframe.api.convertToBigDecimal
4446
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
4547
import org.jetbrains.kotlinx.dataframe.api.convertToByte
@@ -54,9 +56,8 @@ import org.jetbrains.kotlinx.dataframe.api.convertToShort
5456
import org.jetbrains.kotlinx.dataframe.api.convertToString
5557
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
5658
import org.jetbrains.kotlinx.dataframe.api.map
57-
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
59+
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
5860
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
59-
import org.jetbrains.kotlinx.dataframe.typeClass
6061
import org.slf4j.LoggerFactory
6162
import java.io.ByteArrayOutputStream
6263
import java.io.File
@@ -70,14 +71,14 @@ import java.time.LocalTime
7071
import kotlin.reflect.full.isSubtypeOf
7172
import kotlin.reflect.typeOf
7273

73-
public val ignoreWarningMessage: (String) -> Unit = { message: String -> }
74-
public val writeWarningMessage: (String) -> Unit = { message: String -> System.err.println(message) }
74+
public val ignoreMismatchMessage: (ConvertingMismatch) -> Unit = { message: ConvertingMismatch -> }
75+
public val writeMismatchMessage: (ConvertingMismatch) -> Unit = { message: ConvertingMismatch -> System.err.println(message) }
7576

7677
private val logger = LoggerFactory.getLogger(ArrowWriter::class.java)
7778

78-
public val logWarningMessage: (String) -> Unit = { message: String -> logger.debug(message) }
79+
public val logMismatchMessage: (ConvertingMismatch) -> Unit = { message: ConvertingMismatch -> logger.debug(message.toString()) }
7980

80-
public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarningMessage): Field {
81+
public fun AnyCol.toArrowField(mismatchSubscriber: (ConvertingMismatch) -> Unit = ignoreMismatchMessage): Field {
8182
val column = this
8283
val columnType = column.type()
8384
val nullable = columnType.isMarkedNullable
@@ -149,7 +150,7 @@ public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarni
149150
)
150151

151152
else -> {
152-
warningSubscriber("Column ${column.name()} has type ${column.typeClass.java.canonicalName}, will be saved as String")
153+
mismatchSubscriber(ConvertingMismatch.SavedAsString(column.name(), column.typeClass.java))
153154
Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
154155
}
155156
}
@@ -158,8 +159,8 @@ public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarni
158159
* Create Arrow [Schema] matching [this] actual data.
159160
* Columns with not supported types will be interpreted as String
160161
*/
161-
public fun List<AnyCol>.toArrowSchema(warningSubscriber: (String) -> Unit = ignoreWarningMessage): Schema {
162-
val fields = this.map { it.toArrowField(warningSubscriber) }
162+
public fun List<AnyCol>.toArrowSchema(mismatchSubscriber: (ConvertingMismatch) -> Unit = ignoreMismatchMessage): Schema {
163+
val fields = this.map { it.toArrowField(mismatchSubscriber) }
163164
return Schema(fields)
164165
}
165166

@@ -174,8 +175,8 @@ public fun DataFrame<*>.arrowWriter(): ArrowWriter = this.arrowWriter(this.colum
174175
public fun DataFrame<*>.arrowWriter(
175176
targetSchema: Schema,
176177
mode: ArrowWriter.Companion.Mode = ArrowWriter.Companion.Mode.STRICT,
177-
warningSubscriber: (String) -> Unit = ignoreWarningMessage
178-
): ArrowWriter = ArrowWriter(this, targetSchema, mode, warningSubscriber)
178+
mismatchSubscriber: (ConvertingMismatch) -> Unit = ignoreMismatchMessage
179+
): ArrowWriter = ArrowWriter(this, targetSchema, mode, mismatchSubscriber)
179180

180181
/**
181182
* Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray, OutputStream or raw Channel) with [targetSchema].
@@ -185,7 +186,7 @@ public class ArrowWriter(
185186
private val dataFrame: DataFrame<*>,
186187
private val targetSchema: Schema,
187188
private val mode: Mode,
188-
private val warningSubscriber: (String) -> Unit = ignoreWarningMessage
189+
private val mismatchSubscriber: (ConvertingMismatch) -> Unit = ignoreMismatchMessage
189190
) : AutoCloseable {
190191

191192
public companion object {
@@ -296,32 +297,50 @@ public class ArrowWriter(
296297
val containNulls = (column == null || column.hasNulls())
297298
// Convert the column to type specified in field. (If we already have target type, convertTo will do nothing)
298299

299-
fun handleConversionFail(e: Exception): Pair<AnyCol?, Field> {
300+
val (convertedColumn, actualField) = try {
301+
convertColumnToTarget(column, field.type) to field
302+
} catch (e: CellConversionException) {
300303
if (strictType) {
301304
// If conversion failed but strictType is enabled, throw the exception
302-
throw e
305+
val mismatch = ConvertingMismatch.TypeConversionFail.ConversionFailError(e.column, e.row, e)
306+
mismatchSubscriber(mismatch)
307+
throw ConvertingException(mismatch)
303308
} else {
304309
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
305-
warningSubscriber(e.message!!)
306-
val actualType = listOf(column!!).toArrowSchema(warningSubscriber).fields.first().fieldType.type
310+
mismatchSubscriber(ConvertingMismatch.TypeConversionFail.ConversionFailIgnored(e.column, e.row, e))
311+
val actualType = listOf(column!!).toArrowSchema(mismatchSubscriber).fields.first().fieldType.type
307312
val actualField = Field(field.name, FieldType(field.isNullable, actualType, field.fieldType.dictionary), field.children)
308-
return column to actualField
313+
column to actualField
309314
}
310-
}
311-
312-
val (convertedColumn, actualField) = try {
313-
convertColumnToTarget(column, field.type) to field
314-
} catch (e: TypeConversionException) {
315-
handleConversionFail(e)
316315
} catch (e: TypeConverterNotFoundException) {
317-
handleConversionFail(e)
316+
if (strictType) {
317+
// If conversion failed but strictType is enabled, throw the exception
318+
val mismatch = ConvertingMismatch.TypeConversionNotFound.ConversionNotFoundError(field.name, e)
319+
mismatchSubscriber(mismatch)
320+
throw ConvertingException(mismatch)
321+
} else {
322+
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
323+
mismatchSubscriber(ConvertingMismatch.TypeConversionNotFound.ConversionNotFoundIgnored(field.name, e))
324+
val actualType = listOf(column!!).toArrowSchema(mismatchSubscriber).fields.first().fieldType.type
325+
val actualField = Field(field.name, FieldType(field.isNullable, actualType, field.fieldType.dictionary), field.children)
326+
column to actualField
327+
}
318328
}
319329

320330
val vector = if (!actualField.isNullable && containNulls) {
331+
var firstNullValue: Int? = null;
332+
for (i in 0 until (column?.size() ?: -1)) {
333+
if (column!![i] == null) {
334+
firstNullValue = i;
335+
break;
336+
}
337+
}
321338
if (strictNullable) {
322-
throw IllegalArgumentException("Column \"${actualField.name}\" contains nulls but should be not nullable")
339+
val mismatch = ConvertingMismatch.NullableMismatch.NullValueError(actualField.name, firstNullValue)
340+
mismatchSubscriber(mismatch)
341+
throw ConvertingException(mismatch)
323342
} else {
324-
warningSubscriber("Column \"${actualField.name}\" contains nulls but expected not nullable")
343+
mismatchSubscriber(ConvertingMismatch.NullableMismatch.NullValueIgnored(actualField.name, firstNullValue))
325344
Field(actualField.name, FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
326345
}
327346
} else {
@@ -339,7 +358,7 @@ public class ArrowWriter(
339358
}
340359

341360
private fun List<AnyCol>.toVectors(): List<FieldVector> = this.map {
342-
val field = it.toArrowField(warningSubscriber)
361+
val field = it.toArrowField(mismatchSubscriber)
343362
allocateVectorAndInfill(field, it, true, true)
344363
}
345364
/**
@@ -352,9 +371,11 @@ public class ArrowWriter(
352371
val column = dataFrame.getColumnOrNull(field.name)
353372
if (column == null && !field.isNullable) {
354373
if (mode.restrictNarrowing) {
355-
throw IllegalArgumentException("Column \"${field.name}\" is not presented")
374+
val mismatch = ConvertingMismatch.NarrowingMismatch.NotPresentedColumnError(field.name)
375+
mismatchSubscriber(mismatch)
376+
throw ConvertingException(mismatch)
356377
} else {
357-
warningSubscriber("Column \"${field.name}\" is not presented")
378+
mismatchSubscriber(ConvertingMismatch.NarrowingMismatch.NotPresentedColumnIgnored(field.name))
358379
continue
359380
}
360381
}
@@ -371,9 +392,12 @@ public class ArrowWriter(
371392
val otherColumns = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }
372393
if (!mode.restrictWidening) {
373394
vectors.addAll(otherColumns.toVectors())
395+
otherColumns.forEach {
396+
mismatchSubscriber(ConvertingMismatch.WideningMismatch.AddedColumn(it.name))
397+
}
374398
} else {
375399
otherColumns.forEach {
376-
warningSubscriber("Column \"${it.name()}\" is not described in target schema and was ignored")
400+
mismatchSubscriber(ConvertingMismatch.WideningMismatch.RejectedColumn(it.name))
377401
}
378402
}
379403
return VectorSchemaRoot(vectors)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package org.jetbrains.kotlinx.dataframe.io
2+
3+
import org.jetbrains.kotlinx.dataframe.exceptions.CellConversionException
4+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
5+
6+
/**
7+
* Detailed message about any mismatch when saving to Arrow format with user-defined schema that does not match with actual data.
8+
* Can be sent to callback, written to log or encapsulated to exception
9+
*/
10+
public sealed class ConvertingMismatch(
11+
/**Name of the column with mismatch*/
12+
public open val column: String,
13+
/**Number of first row with mismatch (0-based) if defined*/
14+
public open val row: Int?,
15+
/**Original exception if exist*/
16+
public open val cause: Exception?
17+
) {
18+
19+
public sealed class WideningMismatch(column: String): ConvertingMismatch(column, null, null) {
20+
public data class AddedColumn(override val column: String): WideningMismatch(column) {
21+
override fun toString(): String = "Added column \"$column\" not described in target schema"
22+
}
23+
public data class RejectedColumn(override val column: String): WideningMismatch(column) {
24+
override fun toString(): String = "Column \"$column\" is not described in target schema and was ignored"
25+
}
26+
}
27+
public sealed class NarrowingMismatch(column: String): ConvertingMismatch(column, null, null) {
28+
public data class NotPresentedColumnIgnored(override val column: String): NarrowingMismatch(column) {
29+
override fun toString(): String = "Not nullable column \"$column\" is not presented in actual data, saving as is"
30+
}
31+
public data class NotPresentedColumnError(override val column: String): NarrowingMismatch(column) {
32+
override fun toString(): String = "Not nullable column \"$column\" is not presented in actual data, can not save"
33+
}
34+
}
35+
public sealed class TypeConversionNotFound(column: String, cause: TypeConverterNotFoundException): ConvertingMismatch(column, null, cause) {
36+
public data class ConversionNotFoundIgnored(override val column: String, override val cause: TypeConverterNotFoundException): TypeConversionNotFound(column, cause) {
37+
override fun toString(): String = "${cause.message} for column \"$column\", saving as is"
38+
}
39+
public data class ConversionNotFoundError(override val column: String, val e: TypeConverterNotFoundException): TypeConversionNotFound(column, e) {
40+
override fun toString(): String = "${e.message} for column \"$column\", can not save"
41+
}
42+
}
43+
public sealed class TypeConversionFail(column: String, row: Int?, public override val cause: CellConversionException): ConvertingMismatch(column, row, cause) {
44+
public data class ConversionFailIgnored(override val column: String, override val row: Int?, override val cause: CellConversionException): TypeConversionFail(column, row, cause) {
45+
override fun toString(): String = "${cause.message}, saving as is"
46+
}
47+
public data class ConversionFailError(override val column: String, override val row: Int?, override val cause: CellConversionException): TypeConversionFail(column, row, cause) {
48+
override fun toString(): String = "${cause.message}, can not save"
49+
}
50+
}
51+
public data class SavedAsString(override val column: String, val type: Class<*>): ConvertingMismatch(column, null, null) {
52+
override fun toString(): String = "Column \"$column\" has type ${type.canonicalName}, will be saved as String\""
53+
}
54+
public sealed class NullableMismatch(column: String, row: Int?): ConvertingMismatch(column, row, null) {
55+
public data class NullValueIgnored(override val column: String, override val row: Int?): NullableMismatch(column, row) {
56+
override fun toString(): String = "Column \"$column\" contains nulls in row $row but expected not nullable, saving as is"
57+
}
58+
public data class NullValueError(override val column: String, override val row: Int?): NullableMismatch(column, row) {
59+
override fun toString(): String = "Column \"$column\" contains nulls in row $row but expected not nullable, can not save"
60+
}
61+
}
62+
}
63+
64+
public class ConvertingException(public val mismatchCase: ConvertingMismatch): IllegalArgumentException(mismatchCase.toString(), mismatchCase.cause)

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
1616
import org.jetbrains.kotlinx.dataframe.api.map
1717
import org.jetbrains.kotlinx.dataframe.api.remove
1818
import org.jetbrains.kotlinx.dataframe.api.toColumn
19+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
1920
import org.junit.Test
2021
import java.io.File
2122
import java.net.URL
@@ -137,12 +138,12 @@ internal class ArrowKtTest {
137138

138139
@Test
139140
fun testWidening() {
140-
val warnings = ArrayList<String>()
141+
val warnings = ArrayList<ConvertingMismatch>()
141142
val testRestrictWidening = citiesExampleFrame.arrowWriter(
142143
Schema.fromJSON(citiesExampleSchema),
143144
ArrowWriter.Companion.Mode.STRICT
144145
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
145-
warnings.shouldContain("Column \"page_in_wiki\" is not described in target schema and was ignored")
146+
warnings.shouldContain(ConvertingMismatch.WideningMismatch.RejectedColumn("page_in_wiki"))
146147
shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testRestrictWidening)["page_in_wiki"] }
147148

148149
val testAllowWidening = citiesExampleFrame.arrowWriter(
@@ -165,10 +166,10 @@ internal class ArrowKtTest {
165166
Schema.fromJSON(citiesExampleSchema),
166167
ArrowWriter.Companion.Mode.STRICT
167168
).use {
168-
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
169+
shouldThrow<ConvertingException> { it.saveArrowFeatherToByteArray() }
169170
}
170171

171-
val warnings = ArrayList<String>()
172+
val warnings = ArrayList<ConvertingMismatch>()
172173
val testAllowNarrowing = frameWithoutRequiredField.arrowWriter(
173174
Schema.fromJSON(citiesExampleSchema),
174175
ArrowWriter.Companion.Mode(
@@ -178,7 +179,7 @@ internal class ArrowKtTest {
178179
strictNullable = true
179180
)
180181
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
181-
warnings.shouldContain("Column \"settled\" is not presented")
182+
warnings.shouldContain( ConvertingMismatch.NarrowingMismatch.NotPresentedColumnIgnored("settled"))
182183
shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testAllowNarrowing)["settled"] }
183184
}
184185

@@ -191,10 +192,10 @@ internal class ArrowKtTest {
191192
Schema.fromJSON(citiesExampleSchema),
192193
ArrowWriter.Companion.Mode.STRICT
193194
).use {
194-
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
195+
shouldThrow<ConvertingException> { it.saveArrowFeatherToByteArray() }
195196
}
196197

197-
val warnings = ArrayList<String>()
198+
val warnings = ArrayList<ConvertingMismatch>()
198199
val testLoyalType = frameWithIncompatibleField.arrowWriter(
199200
Schema.fromJSON(citiesExampleSchema),
200201
ArrowWriter.Companion.Mode(
@@ -204,7 +205,9 @@ internal class ArrowKtTest {
204205
strictNullable = true
205206
)
206207
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
207-
warnings.shouldContain("Type converter from kotlin.Boolean to kotlinx.datetime.LocalDateTime? is not found")
208+
warnings.map { it.toString() }.shouldContain(
209+
ConvertingMismatch.TypeConversionNotFound.ConversionNotFoundIgnored("settled", TypeConverterNotFoundException(typeOf<Boolean>(), typeOf<kotlinx.datetime.LocalDateTime?>())).toString()
210+
)
208211
DataFrame.readArrowFeather(testLoyalType)["settled"].type() shouldBe typeOf<Boolean>()
209212
}
210213

@@ -217,10 +220,10 @@ internal class ArrowKtTest {
217220
Schema.fromJSON(citiesExampleSchema),
218221
ArrowWriter.Companion.Mode.STRICT
219222
).use {
220-
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
223+
shouldThrow<ConvertingException> { it.saveArrowFeatherToByteArray() }
221224
}
222225

223-
val warnings = ArrayList<String>()
226+
val warnings = ArrayList<ConvertingMismatch>()
224227
val testLoyalNullable = frameWithNulls.arrowWriter(
225228
Schema.fromJSON(citiesExampleSchema),
226229
ArrowWriter.Companion.Mode(
@@ -230,7 +233,7 @@ internal class ArrowKtTest {
230233
strictNullable = false
231234
)
232235
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
233-
warnings.shouldContain("Column \"settled\" contains nulls but expected not nullable")
236+
warnings.shouldContain(ConvertingMismatch.NullableMismatch.NullValueIgnored("settled", 0))
234237
DataFrame.readArrowFeather(testLoyalNullable)["settled"].type() shouldBe typeOf<LocalDateTime?>()
235238
DataFrame.readArrowFeather(testLoyalNullable)["settled"].values() shouldBe arrayOfNulls<LocalDate>(frameRenaming.rowsCount()).asList()
236239
}

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ openapi = "2.1.2"
2323
junit = "4.13.2"
2424
kotestAsserions = "4.6.3"
2525
jsoup = "1.14.3"
26-
arrow = "9.0.0"
26+
arrow = "10.0.0"
2727

2828
[libraries]
2929
ksp-gradle = { group = "com.google.devtools.ksp", name = "symbol-processing-gradle-plugin", version.ref = "ksp" }

0 commit comments

Comments
 (0)