Skip to content

Commit 20451d6

Browse files
committed
Saving to Arrow strict / loyal tests
1 parent a6c51d8 commit 20451d6

File tree

2 files changed

+151
-25
lines changed

2 files changed

+151
-25
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import org.jetbrains.kotlinx.dataframe.api.convertToLocalDateTime
5353
import org.jetbrains.kotlinx.dataframe.api.convertToString
5454
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
5555
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
56+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
5657
import org.jetbrains.kotlinx.dataframe.typeClass
5758
import org.slf4j.Logger
5859
import org.slf4j.LoggerFactory
@@ -244,25 +245,33 @@ public class ArrowWriter(
244245
private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
245246
val containNulls = (column == null || column.hasNulls())
246247
// Convert the column to type specified in field. (If we already have target type, convertTo will do nothing)
247-
val (convertedColumn, actualField) = try {
248-
convertColumnToTarget(column, field.type) to field
249-
} catch (e: TypeConversionException) {
248+
249+
fun handleConversionFail(e: Exception): Pair<AnyCol?, Field> {
250250
if (strictType) {
251251
// If conversion failed but strictType is enabled, throw the exception
252252
throw e
253253
} else {
254254
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
255-
warningSubscriber(e.message)
255+
warningSubscriber(e.message!!)
256256
val actualType = listOf(column!!).toArrowSchema(warningSubscriber).fields.first().fieldType.type
257257
val actualField = Field(field.name, FieldType(field.isNullable, actualType, field.fieldType.dictionary), field.children)
258-
column to actualField
258+
return column to actualField
259259
}
260260
}
261+
262+
val (convertedColumn, actualField) = try {
263+
convertColumnToTarget(column, field.type) to field
264+
} catch (e: TypeConversionException) {
265+
handleConversionFail(e)
266+
} catch (e: TypeConverterNotFoundException) {
267+
handleConversionFail(e)
268+
}
269+
261270
val vector = if (!actualField.isNullable && containNulls) {
262271
if (strictNullable) {
263-
throw IllegalArgumentException("${actualField.name} column contains nulls but should be not nullable")
272+
throw IllegalArgumentException("Column \"${actualField.name}\" contains nulls but should be not nullable")
264273
} else {
265-
warningSubscriber("${actualField.name} column contains nulls but expected not nullable")
274+
warningSubscriber("Column \"${actualField.name}\" contains nulls but expected not nullable")
266275
Field(actualField.name, FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
267276
}
268277
} else {
@@ -288,27 +297,32 @@ public class ArrowWriter(
288297
*/
289298
private fun allocateVectorSchemaRoot(): VectorSchemaRoot {
290299
val mainVectors = LinkedHashMap<String, FieldVector>()
291-
for (field in targetSchema.fields) {
292-
val column = dataFrame.getColumnOrNull(field.name)
293-
if (column == null && !field.isNullable) {
294-
if (mode.restrictNarrowing) {
295-
throw IllegalArgumentException("${field.name} column is not presented")
296-
} else {
297-
warningSubscriber("${field.name} column is not presented")
298-
continue
300+
try {
301+
for (field in targetSchema.fields) {
302+
val column = dataFrame.getColumnOrNull(field.name)
303+
if (column == null && !field.isNullable) {
304+
if (mode.restrictNarrowing) {
305+
throw IllegalArgumentException("Column \"${field.name}\" is not presented")
306+
} else {
307+
warningSubscriber("Column \"${field.name}\" is not presented")
308+
continue
309+
}
299310
}
300-
}
301311

302-
val vector = allocateVectorAndInfill(field, column, mode.strictType, mode.strictNullable)
303-
mainVectors[field.name] = vector
312+
val vector = allocateVectorAndInfill(field, column, mode.strictType, mode.strictNullable)
313+
mainVectors[field.name] = vector
314+
}
315+
} catch (e: Exception) {
316+
mainVectors.values.forEach { it.close() } //Clear buffers before throwing exception
317+
throw e
304318
}
305319
val vectors = ArrayList<FieldVector>()
306320
vectors.addAll(mainVectors.values)
307-
val otherVectors = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }.toVectors()
321+
val otherColumns = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }
308322
if (!mode.restrictWidening) {
309-
vectors.addAll(otherVectors)
323+
vectors.addAll(otherColumns.toVectors())
310324
} else {
311-
otherVectors.forEach { warningSubscriber("${it.name} column is not described in target schema and was ignored") }
325+
otherColumns.forEach { warningSubscriber("Column \"${it.name()}\" is not described in target schema and was ignored") }
312326
}
313327
return VectorSchemaRoot(vectors)
314328
}

dataframe-arrow/src/test/kotlin/ArrowKtTest.kt

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
import io.kotest.assertions.throwables.shouldThrow
2+
import io.kotest.matchers.collections.shouldContain
23
import io.kotest.matchers.shouldBe
34
import org.apache.arrow.vector.types.pojo.Schema
45
import org.apache.arrow.vector.util.Text
56
import org.jetbrains.kotlinx.dataframe.DataColumn
67
import org.jetbrains.kotlinx.dataframe.DataFrame
78
import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
9+
import org.jetbrains.kotlinx.dataframe.api.add
810
import org.jetbrains.kotlinx.dataframe.api.columnOf
11+
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
12+
import org.jetbrains.kotlinx.dataframe.api.copy
913
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
14+
import org.jetbrains.kotlinx.dataframe.api.map
1015
import org.jetbrains.kotlinx.dataframe.api.toColumn
11-
import org.jetbrains.kotlinx.dataframe.io.*
16+
import org.jetbrains.kotlinx.dataframe.api.remove
17+
import org.jetbrains.kotlinx.dataframe.io.ArrowWriter
18+
import org.jetbrains.kotlinx.dataframe.io.arrowWriter
19+
import org.jetbrains.kotlinx.dataframe.io.readArrowFeather
20+
import org.jetbrains.kotlinx.dataframe.io.readArrowIPC
21+
import org.jetbrains.kotlinx.dataframe.io.writeArrowFeather
22+
import org.jetbrains.kotlinx.dataframe.io.saveArrowIPCToByteArray
1223
import org.junit.Test
1324
import java.io.File
1425
import java.net.URL
@@ -112,19 +123,120 @@ internal class ArrowKtTest {
112123
citiesExampleFrame.writeArrowFeather(testFile)
113124
assertEstimation(DataFrame.readArrowFeather(testFile))
114125

115-
val testByteArray = citiesExampleFrame.arrowWriter().saveArrowIPCToByteArray()
126+
val testByteArray = citiesExampleFrame.saveArrowIPCToByteArray()
116127
assertEstimation(DataFrame.readArrowIPC(testByteArray))
117128
}
118129

119130
@Test
120131
fun testWritingBySchema() {
121132
val testFile = File.createTempFile("cities", "arrow")
122-
citiesExampleFrame.arrowWriter(Schema.fromJSON(citiesExampleSchema)).writeArrowFeather(testFile)
133+
citiesExampleFrame.arrowWriter(Schema.fromJSON(citiesExampleSchema)).use { it.writeArrowFeather(testFile) }
123134
val citiesDeserialized = DataFrame.readArrowFeather(testFile, NullabilityOptions.Checking)
124135
citiesDeserialized["population"].type() shouldBe typeOf<Long?>()
125136
citiesDeserialized["area"].type() shouldBe typeOf<Float>()
126137
citiesDeserialized["settled"].type() shouldBe typeOf<LocalDateTime>()
127-
shouldThrow<IllegalArgumentException> { citiesDeserialized["page_in_wiki"] shouldBe null }
138+
shouldThrow<IllegalArgumentException> { citiesDeserialized["page_in_wiki"] }
128139
citiesDeserialized["film_in_youtube"] shouldBe DataColumn.createValueColumn("film_in_youtube", arrayOfNulls<String>(citiesExampleFrame.rowsCount()).asList())
129140
}
141+
142+
@Test
143+
fun testWidening() {
144+
val warnings = ArrayList<String>()
145+
val testRestrictWidening = citiesExampleFrame.arrowWriter(
146+
Schema.fromJSON(citiesExampleSchema),
147+
ArrowWriter.Companion.Mode.STRICT
148+
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
149+
warnings.shouldContain("Column \"page_in_wiki\" is not described in target schema and was ignored")
150+
shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testRestrictWidening)["page_in_wiki"] }
151+
152+
val testAllowWidening = citiesExampleFrame.arrowWriter(
153+
Schema.fromJSON(citiesExampleSchema),
154+
ArrowWriter.Companion.Mode(
155+
restrictWidening = false,
156+
restrictNarrowing = true,
157+
strictType = true,
158+
strictNullable = true
159+
)
160+
).use { it.saveArrowFeatherToByteArray() }
161+
DataFrame.readArrowFeather(testAllowWidening)["page_in_wiki"].values() shouldBe citiesExampleFrame["page_in_wiki"].values().map { it.toString() }
162+
}
163+
164+
@Test
165+
fun testNarrowing() {
166+
val frameWithoutRequiredField = citiesExampleFrame.copy().remove("settled")
167+
168+
frameWithoutRequiredField.arrowWriter(
169+
Schema.fromJSON(citiesExampleSchema),
170+
ArrowWriter.Companion.Mode.STRICT
171+
).use {
172+
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
173+
}
174+
175+
val warnings = ArrayList<String>()
176+
val testAllowNarrowing = frameWithoutRequiredField.arrowWriter(
177+
Schema.fromJSON(citiesExampleSchema),
178+
ArrowWriter.Companion.Mode(
179+
restrictWidening = true,
180+
restrictNarrowing = false,
181+
strictType = true,
182+
strictNullable = true
183+
)
184+
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
185+
warnings.shouldContain("Column \"settled\" is not presented")
186+
shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testAllowNarrowing)["settled"] }
187+
}
188+
189+
@Test
190+
fun testStrictType() {
191+
val frameRenaming = citiesExampleFrame.copy().remove("settled")
192+
val frameWithIncompatibleField = frameRenaming.add(frameRenaming["is_capital"].map { value -> value ?: false }.rename("settled").convertToBoolean())
193+
194+
frameWithIncompatibleField.arrowWriter(
195+
Schema.fromJSON(citiesExampleSchema),
196+
ArrowWriter.Companion.Mode.STRICT
197+
).use {
198+
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
199+
}
200+
201+
val warnings = ArrayList<String>()
202+
val testLoyalType = frameWithIncompatibleField.arrowWriter(
203+
Schema.fromJSON(citiesExampleSchema),
204+
ArrowWriter.Companion.Mode(
205+
restrictWidening = true,
206+
restrictNarrowing = true,
207+
strictType = false,
208+
strictNullable = true
209+
)
210+
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
211+
warnings.shouldContain("Type converter from kotlin.Boolean to kotlinx.datetime.LocalDateTime? is not found")
212+
DataFrame.readArrowFeather(testLoyalType)["settled"].type() shouldBe typeOf<Boolean>()
213+
}
214+
215+
@Test
216+
fun testStrictNullable() {
217+
val frameRenaming = citiesExampleFrame.copy().remove("settled")
218+
val frameWithNulls = frameRenaming.add(DataColumn.createValueColumn("settled", arrayOfNulls<LocalDate>(frameRenaming.rowsCount()).asList()))
219+
220+
frameWithNulls.arrowWriter(
221+
Schema.fromJSON(citiesExampleSchema),
222+
ArrowWriter.Companion.Mode.STRICT
223+
).use {
224+
shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() }
225+
}
226+
227+
val warnings = ArrayList<String>()
228+
val testLoyalNullable = frameWithNulls.arrowWriter(
229+
Schema.fromJSON(citiesExampleSchema),
230+
ArrowWriter.Companion.Mode(
231+
restrictWidening = true,
232+
restrictNarrowing = true,
233+
strictType = true,
234+
strictNullable = false
235+
)
236+
) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() }
237+
warnings.shouldContain("Column \"settled\" contains nulls but expected not nullable")
238+
DataFrame.readArrowFeather(testLoyalNullable)["settled"].type() shouldBe typeOf<LocalDateTime?>()
239+
DataFrame.readArrowFeather(testLoyalNullable)["settled"].values() shouldBe arrayOfNulls<LocalDate>(frameRenaming.rowsCount()).asList()
240+
}
241+
130242
}

0 commit comments

Comments
 (0)