
Commit aee6def

Implement strictNullable

1 parent 5ae26a7

File tree

2 files changed: +109 -60 lines

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 109 additions & 19 deletions
@@ -5,6 +5,8 @@ import kotlinx.datetime.toInstant
 import kotlinx.datetime.toJavaLocalDate
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.*
+import org.apache.arrow.vector.ipc.ArrowFileWriter
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.DateUnit
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.TimeUnit
@@ -16,6 +18,9 @@ import org.apache.arrow.vector.util.Text
 import org.jetbrains.kotlinx.dataframe.AnyCol
 import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.api.*
+import java.io.ByteArrayOutputStream
+import java.nio.channels.Channels
+import java.nio.channels.WritableByteChannel
 import java.time.LocalDate
 import java.time.LocalDateTime
 import java.time.LocalTime
@@ -63,6 +68,10 @@ public fun List<AnyCol>.toArrowSchema(): Schema {
     return Schema(fields)
 }
 
+public fun DataFrame<*>.arrowWriter(): ArrowWriter = ArrowWriter(this, this.columns().toArrowSchema())
+
+public fun DataFrame<*>.arrowWriter(targetSchema: Schema): ArrowWriter = ArrowWriter(this, targetSchema)
+
 /**
  * Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray or stream) with [targetSchema].
  *
@@ -124,21 +133,34 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
 
     }
 
-    private fun List<AnyCol>.toVectors(): List<FieldVector> {
-        val actualSchema = this.toArrowSchema()
-        val vectors = ArrayList<FieldVector>()
-        for ((i, field) in actualSchema.fields.withIndex()) {
-            val column = this[i]
-            val vector = field.createVector(allocator)!!
-            allocateVector(vector, dataFrame.rowsCount())
+    private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
+        val containNulls = (column == null || column.hasNulls())
+        val vector = if (!field.isNullable && containNulls) {
+            if (strictNullable) {
+                throw Exception("${field.name} column contains nulls but should be not nullable")
+            } else {
+                Field(field.name, FieldType(true, field.fieldType.type, field.fieldType.dictionary), field.children).createVector(allocator)!!
+            }
+        } else {
+            field.createVector(allocator)!!
+        }
+
+        allocateVector(vector, dataFrame.rowsCount())
+        if (column == null) {
+            check(field.isNullable)
+            infillWithNulls(vector, dataFrame.rowsCount())
+        } else {
             infillVector(vector, column)
-            vectors.add(vector)
         }
-        return vectors
+        return vector
     }
 
-    /**
-     * Create Arrow VectorSchemaRoot with [dataFrame] content casted to [targetSchema].
+    private fun List<AnyCol>.toVectors(): List<FieldVector> = this.toArrowSchema().fields.mapIndexed { i, field ->
+        allocateVectorAndInfill(field, this[i], true, true)
+    }
+
+    /**
+     * Create Arrow VectorSchemaRoot with [dataFrame] content cast to [targetSchema].
      * If [restrictWidening] is true, [dataFrame] columns not described in [targetSchema] would not be saved (otherwise, would be saved as is).
     * If [restrictNarrowing] is true, [targetSchema] fields that are not nullable and do not exist in [dataFrame] will produce exception (otherwise, would not be saved).
     * If [strictType] is true, [dataFrame] columns described in [targetSchema] with non-compatible type will produce exception (otherwise, would be saved as is).
@@ -153,7 +175,6 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
         val mainVectors = LinkedHashMap<String, FieldVector>()
         for (field in targetSchema.fields) {
             val column = dataFrame.getColumnOrNull(field.name)
-            val vector = field.createVector(allocator)!!
             if (column == null && !field.isNullable) {
                 if (restrictNarrowing) {
                     throw Exception("${field.name} column is not presented")
@@ -162,13 +183,7 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
                 }
             }
 
-            allocateVector(vector, dataFrame.rowsCount())
-            if (column == null) {
-                check(field.isNullable)
-                infillWithNulls(vector, dataFrame.rowsCount())
-            } else {
-                infillVector(vector, column)
-            }
+            val vector = allocateVectorAndInfill(field, column, strictType, strictNullable)
             mainVectors[field.name] = vector
         }
         val vectors = ArrayList<FieldVector>()
@@ -180,7 +195,82 @@ public class ArrowWriter(public val dataFrame: DataFrame<*>, public val targetSc
         return VectorSchemaRoot(vectors)
     }
 
+    public fun featherToChannel(channel: WritableByteChannel) {
+        allocateVectorSchemaRoot(false, false, false, false).use { vectorSchemaRoot ->
+            ArrowFileWriter(vectorSchemaRoot, null, channel).use { writer ->
+                writer.writeBatch();
+            }
+        }
+    }
+
+    public fun ipcToChannel(channel: WritableByteChannel) {
+        allocateVectorSchemaRoot(false, false, false, false).use { vectorSchemaRoot ->
+            ArrowStreamWriter(vectorSchemaRoot, null, channel).use { writer ->
+                writer.writeBatch();
+            }
+        }
+    }
+
+    public fun featherToByteArray(): ByteArray {
+        ByteArrayOutputStream().use { byteArrayStream ->
+            Channels.newChannel(byteArrayStream).use { channel ->
+                featherToChannel(channel)
+                return byteArrayStream.toByteArray()
+            }
+        }
+    }
+
+    public fun iptToByteArray(): ByteArray {
+        ByteArrayOutputStream().use { byteArrayStream ->
+            Channels.newChannel(byteArrayStream).use { channel ->
+                ipcToChannel(channel)
+                return byteArrayStream.toByteArray()
+            }
+        }
+    }
+
     override fun close() {
         allocator.close()
     }
 }
+//
+//// IPC saving block
+//
+///**
+// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to new or existing [file].
+// * If file exists, it can be recreated or expanded.
+// */
+//public fun AnyFrame.writeArrowIPC(file: File, append: Boolean = true) {
+//
+//}
+//
+///**
+// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to [ByteArray]
+// */
+//public fun AnyFrame.writeArrowIPCToByteArray() {
+//
+//}
+//
+//// Feather saving block
+//
+///**
+// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to new or existing [file].
+// * If file exists, it would be recreated.
+// */
+//public fun AnyFrame.writeArrowFeather(file: File) {
+//
+//}
+//
+///**
+// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to [ByteArray]
+// */
+//public fun DataFrame.Companion.writeArrowFeatherToByteArray(): ByteArray {
+//
+//}
+//
+///**
+// * Write [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) from existing [stream]
+// */
+//public fun DataFrame.Companion.writeArrowFeather(stream: OutputStream) {
+//
+//}
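
For context, a minimal usage sketch of the API this commit adds (not part of the diff). It assumes the dataframe-arrow module is on the classpath; the sample frame is hypothetical, and the writer is closed via use, which works because ArrowWriter overrides close().

import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.io.arrowWriter

fun main() {
    // Hypothetical sample frame; "age" is inferred as a nullable Int column.
    val df = dataFrameOf("name", "age")(
        "Alice", 29,
        "Bob", null,
    )

    // arrowWriter() derives the target schema from the frame itself, so the
    // schema's nullability always matches the data and strictNullable cannot fire.
    df.arrowWriter().use { writer ->
        val feather: ByteArray = writer.featherToByteArray() // Arrow random access format
        val ipc: ByteArray = writer.iptToByteArray()         // Arrow streaming format (method name as committed)
        println("feather: ${feather.size} bytes, ipc: ${ipc.size} bytes")
    }
}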

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 0 additions & 41 deletions
@@ -400,44 +400,3 @@ public fun DataFrame.Companion.readArrowFeather(
 } else {
     readArrowFeather(File(path), nullability)
 }
-
-//// IPC saving block
-//
-///**
-// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to new or existing [file].
-// * If file exists, it can be recreated or expanded.
-// */
-//public fun AnyFrame.writeArrowIPC(file: File, append: Boolean = true) {
-//
-//}
-//
-///**
-// * Save data to [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format), write to [ByteArray]
-// */
-//public fun AnyFrame.writeArrowIPCToByteArray() {
-//
-//}
-//
-//// Feather saving block
-//
-///**
-// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to new or existing [file].
-// * If file exists, it would be recreated.
-// */
-//public fun AnyFrame.writeArrowFeather(file: File) {
-//
-//}
-//
-///**
-// * Save data to [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files), write to [ByteArray]
-// */
-//public fun DataFrame.Companion.writeArrowFeatherToByteArray(): ByteArray {
-//
-//}
-//
-///**
-// * Write [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) from existing [stream]
-// */
-//public fun DataFrame.Companion.writeArrowFeather(stream: OutputStream) {
-//
-//}
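
To illustrate what strictNullable changes, here is a sketch (not from the commit) that writes a frame containing nulls against an explicitly non-nullable target schema. The schema construction uses Arrow's stock pojo API; the commented allocateVectorSchemaRoot call assumes the method is callable with the four flags in the order suggested by the call sites above (restrictWidening, restrictNarrowing, strictType, strictNullable).

import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.Schema
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.io.arrowWriter

fun main() {
    // "age" contains a null, but the target schema declares the field non-nullable.
    val df = dataFrameOf("age")(29, null)
    val target = Schema(listOf(Field.notNullable("age", ArrowType.Int(32, true))))

    df.arrowWriter(target).use { writer ->
        // featherToChannel and ipcToChannel pass strictNullable = false, so the
        // mismatch is tolerated: allocateVectorAndInfill recreates the field as
        // nullable and the write succeeds.
        writer.featherToByteArray()

        // With strictNullable = true, the same allocation would throw instead:
        //   Exception: "age column contains nulls but should be not nullable"
        // writer.allocateVectorSchemaRoot(false, false, false, true).use { }
    }
}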
