Skip to content

Commit 21a1749

Browse files
committed
NullabilityOptions
1 parent 8ccd85b commit 21a1749

File tree

4 files changed

+180
-91
lines changed

4 files changed

+180
-91
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/TypeConversions.kt

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnSet
1717
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
1818
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
1919
import org.jetbrains.kotlinx.dataframe.impl.GroupByImpl
20+
import org.jetbrains.kotlinx.dataframe.impl.anyNull
2021
import org.jetbrains.kotlinx.dataframe.impl.asList
2122
import org.jetbrains.kotlinx.dataframe.impl.columnName
2223
import org.jetbrains.kotlinx.dataframe.impl.columns.ColumnAccessorImpl
@@ -189,6 +190,50 @@ public enum class Infer {
189190
Type
190191
}
191192

193+
/**
194+
* Indicates how [DataColumn.hasNulls] (or, more accurately, DataColumn.type.isMarkedNullable) should be initialized from
195+
* expected schema and actual data when reading schema-defined data formats.
196+
*/
197+
public enum class NullabilityOptions {
198+
/**
199+
* Use only actual data, set [DataColumn.hasNulls] to true if and only if there are null values in the column.
200+
* On empty dataset use False.
201+
*/
202+
Keeping,
203+
204+
/**
205+
* Set [DataColumn.hasNulls] to expected value. Throw exception if column should be not nullable but there are null values.
206+
*/
207+
Checking,
208+
209+
/**
210+
* Set [DataColumn.hasNulls] to expected value by default. Change False to True if column should be not nullable but there are null values.
211+
*/
212+
Widening
213+
}
214+
215+
public class NullabilityException() : Exception()
216+
217+
/**
218+
* @return if column should be marked nullable for current [NullabilityOptions] value with actual [data] and [expectedNulls] per some schema/signature.
219+
* @throws [NullabilityException] for [NullabilityOptions.Checking] if [expectedNulls] is false and [data] contains nulls.
220+
*/
221+
public fun NullabilityOptions.applyNullability(data: List<Any?>, expectedNulls: Boolean): Boolean {
222+
val hasNulls = data.anyNull()
223+
return when (this) {
224+
NullabilityOptions.Keeping -> hasNulls
225+
NullabilityOptions.Checking -> {
226+
if (!expectedNulls && hasNulls) {
227+
throw NullabilityException()
228+
}
229+
expectedNulls
230+
}
231+
NullabilityOptions.Widening -> {
232+
expectedNulls || hasNulls
233+
}
234+
}
235+
}
236+
192237
public inline fun <reified T> Iterable<T>.toColumn(
193238
name: String = "",
194239
infer: Infer = Infer.None

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 71 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
3737
import org.jetbrains.kotlinx.dataframe.DataColumn
3838
import org.jetbrains.kotlinx.dataframe.DataFrame
3939
import org.jetbrains.kotlinx.dataframe.api.Infer
40+
import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
41+
import org.jetbrains.kotlinx.dataframe.api.applyNullability
42+
import org.jetbrains.kotlinx.dataframe.api.NullabilityException
4043
import org.jetbrains.kotlinx.dataframe.api.cast
4144
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
4245
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
@@ -108,13 +111,13 @@ internal fun <T> Iterable<DataFrame<T>>.concatKeepingSchema(): DataFrame<T> {
108111
/**
109112
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
110113
*/
111-
public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
114+
public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame {
112115
ArrowStreamReader(channel, allocator).use { reader ->
113116
val dfs = buildList {
114117
val root = reader.vectorSchemaRoot
115118
val schema = root.schema
116119
while (reader.loadNextBatch()) {
117-
val df = schema.fields.map { f -> readField(root, f) }.toDataFrame()
120+
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
118121
add(df)
119122
}
120123
}
@@ -125,14 +128,14 @@ public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, alloca
125128
/**
126129
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [channel]
127130
*/
128-
public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
131+
public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame {
129132
ArrowFileReader(channel, allocator).use { reader ->
130133
val dfs = buildList {
131134
reader.recordBlocks.forEach { block ->
132135
reader.loadRecordBatch(block)
133136
val root = reader.vectorSchemaRoot
134137
val schema = root.schema
135-
val df = schema.fields.map { f -> readField(root, f) }.toDataFrame()
138+
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
136139
add(df)
137140
}
138141
}
@@ -222,102 +225,115 @@ private fun LargeVarCharVector.values(range: IntRange): List<String?> = range.ma
222225
}
223226
}
224227

225-
private inline fun <reified T> List<T>.withType(nullability: Boolean) = this to typeOf<T>().withNullability(nullability)
226-
227-
private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
228-
val range = 0 until root.rowCount
229-
val (list, type) = when (val vector = root.getVector(field)) {
230-
is VarCharVector -> vector.values(range).withType(field.isNullable)
231-
is LargeVarCharVector -> vector.values(range).withType(field.isNullable)
232-
is VarBinaryVector -> vector.values(range).withType(field.isNullable)
233-
is LargeVarBinaryVector -> vector.values(range).withType(field.isNullable)
234-
is BitVector -> vector.values(range).withType(field.isNullable)
235-
is SmallIntVector -> vector.values(range).withType(field.isNullable)
236-
is TinyIntVector -> vector.values(range).withType(field.isNullable)
237-
is UInt1Vector -> vector.values(range).withType(field.isNullable)
238-
is UInt2Vector -> vector.values(range).withType(field.isNullable)
239-
is UInt4Vector -> vector.values(range).withType(field.isNullable)
240-
is UInt8Vector -> vector.values(range).withType(field.isNullable)
241-
is IntVector -> vector.values(range).withType(field.isNullable)
242-
is BigIntVector -> vector.values(range).withType(field.isNullable)
243-
is DecimalVector -> vector.values(range).withType(field.isNullable)
244-
is Decimal256Vector -> vector.values(range).withType(field.isNullable)
245-
is Float8Vector -> vector.values(range).withType(field.isNullable)
246-
is Float4Vector -> vector.values(range).withType(field.isNullable)
247-
is DurationVector -> vector.values(range).withType(field.isNullable)
248-
is DateDayVector -> vector.values(range).withType(field.isNullable)
249-
is DateMilliVector -> vector.values(range).withType(field.isNullable)
250-
is TimeNanoVector -> vector.values(range).withType(field.isNullable)
251-
is TimeMicroVector -> vector.values(range).withType(field.isNullable)
252-
is TimeMilliVector -> vector.values(range).withType(field.isNullable)
253-
is TimeSecVector -> vector.values(range).withType(field.isNullable)
254-
is StructVector -> vector.values(range).withType(field.isNullable)
255-
else -> {
256-
TODO("not fully implemented")
228+
private inline fun <reified T> List<T?>.withTypeNullable(expectedNulls: Boolean, nullabilityOptions: NullabilityOptions): Pair<List<T?>, KType> {
229+
val nullable = nullabilityOptions.applyNullability(this, expectedNulls)
230+
return this to typeOf<T>().withNullability(nullable)
231+
}
232+
233+
private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {
234+
try {
235+
val range = 0 until root.rowCount
236+
val (list, type) = when (val vector = root.getVector(field)) {
237+
is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
238+
is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
239+
is VarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
240+
is LargeVarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
241+
is BitVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
242+
is SmallIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
243+
is TinyIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
244+
is UInt1Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
245+
is UInt2Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
246+
is UInt4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
247+
is UInt8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
248+
is IntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
249+
is BigIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
250+
is DecimalVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
251+
is Decimal256Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
252+
is Float8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
253+
is Float4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
254+
is DurationVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
255+
is DateDayVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
256+
is DateMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
257+
is TimeNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
258+
is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
259+
is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
260+
is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
261+
is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
262+
else -> {
263+
TODO("not fully implemented")
264+
}
257265
}
266+
return DataColumn.createValueColumn(field.name, list, type, Infer.None)
267+
} catch (unexpectedNull: NullabilityException) {
268+
throw IllegalArgumentException("Column `${field.name}` should be not nullable but has nulls")
258269
}
259-
return DataColumn.createValueColumn(field.name, list, type, Infer.None)
260270
}
261271

262272
// IPC reading block
263273

264274
/**
265275
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [file]
266276
*/
267-
public fun DataFrame.Companion.readArrowIPC(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowIPC(it) }
277+
public fun DataFrame.Companion.readArrowIPC(file: File, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
278+
Files.newByteChannel(file.toPath()).use { readArrowIPC(it, nullability = nullability) }
268279

269280
/**
270281
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [byteArray]
271282
*/
272-
public fun DataFrame.Companion.readArrowIPC(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowIPC(it) }
283+
public fun DataFrame.Companion.readArrowIPC(byteArray: ByteArray, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
284+
SeekableInMemoryByteChannel(byteArray).use { readArrowIPC(it, nullability = nullability) }
273285

274286
/**
275287
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [stream]
276288
*/
277-
public fun DataFrame.Companion.readArrowIPC(stream: InputStream): AnyFrame = Channels.newChannel(stream).use { readArrowIPC(it) }
289+
public fun DataFrame.Companion.readArrowIPC(stream: InputStream, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
290+
Channels.newChannel(stream).use { readArrowIPC(it, nullability = nullability) }
278291

279292
/**
280293
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [url]
281294
*/
282-
public fun DataFrame.Companion.readArrowIPC(url: URL): AnyFrame =
295+
public fun DataFrame.Companion.readArrowIPC(url: URL, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
283296
when {
284-
isFile(url) -> readArrowIPC(urlAsFile(url))
285-
isProtocolSupported(url) -> url.openStream().use { readArrowIPC(it) }
297+
isFile(url) -> readArrowIPC(urlAsFile(url), nullability)
298+
isProtocolSupported(url) -> url.openStream().use { readArrowIPC(it, nullability) }
286299
else -> {
287300
throw IllegalArgumentException("Invalid protocol for url $url")
288301
}
289302
}
290303

291-
public fun DataFrame.Companion.readArrowIPC(path: String): AnyFrame = if (isURL(path)) {
292-
readArrowIPC(URL(path))
304+
public fun DataFrame.Companion.readArrowIPC(path: String, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame = if (isURL(path)) {
305+
readArrowIPC(URL(path), nullability)
293306
} else {
294-
readArrowIPC(File(path))
307+
readArrowIPC(File(path), nullability)
295308
}
296309

297310
// Feather reading block
298311

299312
/**
300313
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [file]
301314
*/
302-
public fun DataFrame.Companion.readArrowFeather(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowFeather(it) }
315+
public fun DataFrame.Companion.readArrowFeather(file: File, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
316+
Files.newByteChannel(file.toPath()).use { readArrowFeather(it, nullability = nullability) }
303317

304318
/**
305319
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [byteArray]
306320
*/
307-
public fun DataFrame.Companion.readArrowFeather(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowFeather(it) }
321+
public fun DataFrame.Companion.readArrowFeather(byteArray: ByteArray, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
322+
SeekableInMemoryByteChannel(byteArray).use { readArrowFeather(it, nullability = nullability) }
308323

309324
/**
310325
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [stream]
311326
*/
312-
public fun DataFrame.Companion.readArrowFeather(stream: InputStream): AnyFrame = readArrowFeather(stream.readBytes())
327+
public fun DataFrame.Companion.readArrowFeather(stream: InputStream, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
328+
readArrowFeather(stream.readBytes(), nullability)
313329

314330
/**
315331
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [url]
316332
*/
317-
public fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame =
333+
public fun DataFrame.Companion.readArrowFeather(url: URL, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame =
318334
when {
319-
isFile(url) -> readArrowFeather(urlAsFile(url))
320-
isProtocolSupported(url) -> readArrowFeather(url.readBytes())
335+
isFile(url) -> readArrowFeather(urlAsFile(url), nullability)
336+
isProtocolSupported(url) -> readArrowFeather(url.readBytes(), nullability)
321337
else -> {
322338
throw IllegalArgumentException("Invalid protocol for url $url")
323339
}
@@ -326,8 +342,8 @@ public fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame =
326342
/**
327343
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [path]
328344
*/
329-
public fun DataFrame.Companion.readArrowFeather(path: String): AnyFrame = if (isURL(path)) {
330-
readArrowFeather(URL(path))
345+
public fun DataFrame.Companion.readArrowFeather(path: String, nullability: NullabilityOptions = NullabilityOptions.Keeping): AnyFrame = if (isURL(path)) {
346+
readArrowFeather(URL(path), nullability)
331347
} else {
332-
readArrowFeather(File(path))
348+
readArrowFeather(File(path), nullability)
333349
}

0 commit comments

Comments
 (0)