@@ -37,6 +37,9 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
37
37
import org.jetbrains.kotlinx.dataframe.DataColumn
38
38
import org.jetbrains.kotlinx.dataframe.DataFrame
39
39
import org.jetbrains.kotlinx.dataframe.api.Infer
40
+ import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
41
+ import org.jetbrains.kotlinx.dataframe.api.applyNullability
42
+ import org.jetbrains.kotlinx.dataframe.api.NullabilityException
40
43
import org.jetbrains.kotlinx.dataframe.api.cast
41
44
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
42
45
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
@@ -108,13 +111,13 @@ internal fun <T> Iterable<DataFrame<T>>.concatKeepingSchema(): DataFrame<T> {
108
111
/* *
109
112
* Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel]
110
113
*/
111
- public fun DataFrame.Companion.readArrowIPC (channel : ReadableByteChannel , allocator : RootAllocator = Allocator .ROOT ): AnyFrame {
114
+ public fun DataFrame.Companion.readArrowIPC (channel : ReadableByteChannel , allocator : RootAllocator = Allocator .ROOT , nullability : NullabilityOptions = NullabilityOptions . Keeping ): AnyFrame {
112
115
ArrowStreamReader (channel, allocator).use { reader ->
113
116
val dfs = buildList {
114
117
val root = reader.vectorSchemaRoot
115
118
val schema = root.schema
116
119
while (reader.loadNextBatch()) {
117
- val df = schema.fields.map { f -> readField(root, f) }.toDataFrame()
120
+ val df = schema.fields.map { f -> readField(root, f, nullability ) }.toDataFrame()
118
121
add(df)
119
122
}
120
123
}
@@ -125,14 +128,14 @@ public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, alloca
125
128
/* *
126
129
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [channel]
127
130
*/
128
- public fun DataFrame.Companion.readArrowFeather (channel : SeekableByteChannel , allocator : RootAllocator = Allocator .ROOT ): AnyFrame {
131
+ public fun DataFrame.Companion.readArrowFeather (channel : SeekableByteChannel , allocator : RootAllocator = Allocator .ROOT , nullability : NullabilityOptions = NullabilityOptions . Keeping ): AnyFrame {
129
132
ArrowFileReader (channel, allocator).use { reader ->
130
133
val dfs = buildList {
131
134
reader.recordBlocks.forEach { block ->
132
135
reader.loadRecordBatch(block)
133
136
val root = reader.vectorSchemaRoot
134
137
val schema = root.schema
135
- val df = schema.fields.map { f -> readField(root, f) }.toDataFrame()
138
+ val df = schema.fields.map { f -> readField(root, f, nullability ) }.toDataFrame()
136
139
add(df)
137
140
}
138
141
}
@@ -222,102 +225,115 @@ private fun LargeVarCharVector.values(range: IntRange): List<String?> = range.ma
222
225
}
223
226
}
224
227
225
/**
 * Pairs this list with the [KType] of its element type [T], with the type's nullability decided by
 * [nullabilityOptions].
 *
 * @param expectedNulls nullability declared for these values by their source; every call site in this
 * file passes the Arrow field's `isNullable` flag
 * @param nullabilityOptions strategy whose [applyNullability] result becomes the nullability of the
 * returned type
 * @return this list (unchanged) together with the resulting element type
 */
private inline fun <reified T> List<T?>.withTypeNullable(
    expectedNulls: Boolean,
    nullabilityOptions: NullabilityOptions,
): Pair<List<T?>, KType> {
    // NOTE(review): presumably this is the call that raises NullabilityException for unexpected nulls - confirm
    val nullable = nullabilityOptions.applyNullability(this, expectedNulls)
    return this to typeOf<T>().withNullability(nullable)
}

/**
 * Converts the vector that [root] holds for [field] into a value column named after the field.
 *
 * All rows of the current batch (`0 until root.rowCount`) are read through the `values` extension
 * matching the concrete vector class, and the column is created with [Infer.None], i.e. with exactly
 * the type computed by [withTypeNullable].
 *
 * @param root schema root holding the currently loaded batch
 * @param field Arrow field to read; supplies the column name and the declared nullability
 * @param nullability strategy used to reconcile the declared nullability with the actual values
 * @return the column built from the field's values
 * @throws IllegalArgumentException if a [NullabilityException] is raised while reading the field;
 * the original exception is attached as the cause
 * @throws NotImplementedError if the vector class has no branch below
 */
private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol {
    try {
        val range = 0 until root.rowCount
        // One branch per vector class: `values` is resolved statically for each class, and the reified
        // element type in `withTypeNullable` differs between branches, so they cannot be merged.
        val (list, type) = when (val vector = root.getVector(field)) {
            is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is VarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is LargeVarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is BitVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is SmallIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is TinyIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is UInt1Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is UInt2Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is UInt4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is UInt8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is IntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is BigIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is DecimalVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is Decimal256Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is Float8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is Float4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is DurationVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is DateDayVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is DateMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is TimeNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
            else -> {
                TODO("not fully implemented")
            }
        }
        return DataColumn.createValueColumn(field.name, list, type, Infer.None)
    } catch (unexpectedNull: NullabilityException) {
        // Keep the original exception as the cause so its stack trace is not lost when rethrowing.
        throw IllegalArgumentException(
            "Column `${field.name}` should be not nullable but has nulls",
            unexpectedNull,
        )
    }
}
261
271
262
272
// IPC reading block
263
273
264
274
/**
 * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [file]
 *
 * A byte channel is opened on [file], handed to the channel-based `readArrowIPC` overload together
 * with [nullability], and closed once reading has finished.
 */
public fun DataFrame.Companion.readArrowIPC(
    file: File,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val fileChannel = Files.newByteChannel(file.toPath())
    return fileChannel.use { channel -> readArrowIPC(channel, nullability = nullability) }
}
268
279
269
280
/**
 * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [byteArray]
 *
 * The bytes are wrapped in an in-memory channel, which is handed to the channel-based `readArrowIPC`
 * overload together with [nullability] and closed once reading has finished.
 */
public fun DataFrame.Companion.readArrowIPC(
    byteArray: ByteArray,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val memoryChannel = SeekableInMemoryByteChannel(byteArray)
    return memoryChannel.use { channel -> readArrowIPC(channel, nullability = nullability) }
}
273
285
274
286
/**
 * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [stream]
 *
 * The stream is wrapped in a channel, which is handed to the channel-based `readArrowIPC` overload
 * together with [nullability] and closed once reading has finished.
 */
public fun DataFrame.Companion.readArrowIPC(
    stream: InputStream,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val streamChannel = Channels.newChannel(stream)
    return streamChannel.use { channel -> readArrowIPC(channel, nullability = nullability) }
}
278
291
279
292
/**
 * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [url]
 *
 * A [url] accepted by `isFile` is read through the file overload; otherwise, a [url] accepted by
 * `isProtocolSupported` is opened as a stream and read through the stream overload.
 *
 * @throws IllegalArgumentException if [url] is accepted by neither check
 */
public fun DataFrame.Companion.readArrowIPC(
    url: URL,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    if (isFile(url)) {
        return readArrowIPC(urlAsFile(url), nullability)
    }
    if (isProtocolSupported(url)) {
        return url.openStream().use { stream -> readArrowIPC(stream, nullability) }
    }
    throw IllegalArgumentException("Invalid protocol for url $url")
}
290
303
291
/**
 * Read Arrow interprocess streaming format data from [path].
 *
 * When `isURL` accepts [path] it is read through the [URL] overload, otherwise through the [File]
 * overload; [nullability] is forwarded in both cases.
 */
public fun DataFrame.Companion.readArrowIPC(
    path: String,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame =
    when {
        isURL(path) -> readArrowIPC(URL(path), nullability)
        else -> readArrowIPC(File(path), nullability)
    }
296
309
297
310
// Feather reading block
298
311
299
312
/**
 * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [file]
 *
 * A byte channel is opened on [file], handed to the channel-based `readArrowFeather` overload together
 * with [nullability], and closed once reading has finished.
 */
public fun DataFrame.Companion.readArrowFeather(
    file: File,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val fileChannel = Files.newByteChannel(file.toPath())
    return fileChannel.use { channel -> readArrowFeather(channel, nullability = nullability) }
}
303
317
304
318
/**
 * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [byteArray]
 *
 * The bytes are wrapped in an in-memory channel, which is handed to the channel-based
 * `readArrowFeather` overload together with [nullability] and closed once reading has finished.
 */
public fun DataFrame.Companion.readArrowFeather(
    byteArray: ByteArray,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val memoryChannel = SeekableInMemoryByteChannel(byteArray)
    return memoryChannel.use { channel -> readArrowFeather(channel, nullability = nullability) }
}
308
323
309
324
/**
 * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [stream]
 *
 * The remaining content of [stream] is read into a byte array, which is then passed to the
 * byte-array overload together with [nullability]. This function does not close [stream].
 */
public fun DataFrame.Companion.readArrowFeather(
    stream: InputStream,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame {
    val bytes = stream.readBytes()
    return readArrowFeather(bytes, nullability)
}
313
329
314
330
/* *
315
331
* Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [url]
316
332
*/
317
- public fun DataFrame.Companion.readArrowFeather (url : URL ): AnyFrame =
333
+ public fun DataFrame.Companion.readArrowFeather (url : URL , nullability : NullabilityOptions = NullabilityOptions . Keeping ): AnyFrame =
318
334
when {
319
- isFile(url) -> readArrowFeather(urlAsFile(url))
320
- isProtocolSupported(url) -> readArrowFeather(url.readBytes())
335
+ isFile(url) -> readArrowFeather(urlAsFile(url), nullability )
336
+ isProtocolSupported(url) -> readArrowFeather(url.readBytes(), nullability )
321
337
else -> {
322
338
throw IllegalArgumentException (" Invalid protocol for url $url " )
323
339
}
@@ -326,8 +342,8 @@ public fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame =
326
342
/**
 * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [path]
 *
 * When `isURL` accepts [path] it is read through the [URL] overload, otherwise through the [File]
 * overload; [nullability] is forwarded in both cases.
 */
public fun DataFrame.Companion.readArrowFeather(
    path: String,
    nullability: NullabilityOptions = NullabilityOptions.Keeping,
): AnyFrame =
    when {
        isURL(path) -> readArrowFeather(URL(path), nullability)
        else -> readArrowFeather(File(path), nullability)
    }
0 commit comments