Skip to content

Commit 68bf499

Browse files
committed
warningSubscriber
1 parent 90cc6f8 commit 68bf499

File tree

1 file changed

+75
-20
lines changed
  • dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io

1 file changed

+75
-20
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,29 @@ import kotlinx.datetime.TimeZone
44
import kotlinx.datetime.toInstant
55
import kotlinx.datetime.toJavaLocalDate
66
import org.apache.arrow.memory.RootAllocator
7-
import org.apache.arrow.vector.*
7+
import org.apache.arrow.vector.BaseFixedWidthVector
8+
import org.apache.arrow.vector.BaseVariableWidthVector
9+
import org.apache.arrow.vector.FieldVector
10+
import org.apache.arrow.vector.FixedWidthVector
11+
import org.apache.arrow.vector.LargeVarCharVector
12+
import org.apache.arrow.vector.TinyIntVector
13+
import org.apache.arrow.vector.SmallIntVector
14+
import org.apache.arrow.vector.IntVector
15+
import org.apache.arrow.vector.BigIntVector
16+
import org.apache.arrow.vector.BitVector
17+
import org.apache.arrow.vector.DateDayVector
18+
import org.apache.arrow.vector.DateMilliVector
19+
import org.apache.arrow.vector.DecimalVector
20+
import org.apache.arrow.vector.Decimal256Vector
21+
import org.apache.arrow.vector.Float4Vector
22+
import org.apache.arrow.vector.Float8Vector
23+
import org.apache.arrow.vector.TimeMicroVector
24+
import org.apache.arrow.vector.TimeMilliVector
25+
import org.apache.arrow.vector.TimeNanoVector
26+
import org.apache.arrow.vector.TimeSecVector
27+
import org.apache.arrow.vector.VariableWidthVector
28+
import org.apache.arrow.vector.VarCharVector
29+
import org.apache.arrow.vector.VectorSchemaRoot
830
import org.apache.arrow.vector.ipc.ArrowFileWriter
931
import org.apache.arrow.vector.ipc.ArrowStreamWriter
1032
import org.apache.arrow.vector.types.DateUnit
@@ -18,8 +40,22 @@ import org.apache.arrow.vector.util.Text
1840
import org.jetbrains.kotlinx.dataframe.AnyCol
1941
import org.jetbrains.kotlinx.dataframe.AnyFrame
2042
import org.jetbrains.kotlinx.dataframe.DataFrame
21-
import org.jetbrains.kotlinx.dataframe.api.*
43+
import org.jetbrains.kotlinx.dataframe.api.convertTo
44+
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
45+
import org.jetbrains.kotlinx.dataframe.api.convertToBigDecimal
46+
import org.jetbrains.kotlinx.dataframe.api.convertToDouble
47+
import org.jetbrains.kotlinx.dataframe.api.convertToFloat
48+
import org.jetbrains.kotlinx.dataframe.api.convertToLong
49+
import org.jetbrains.kotlinx.dataframe.api.convertToInt
50+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalDate
51+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalTime
52+
import org.jetbrains.kotlinx.dataframe.api.convertToLocalDateTime
53+
import org.jetbrains.kotlinx.dataframe.api.convertToString
54+
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
2255
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
56+
import org.jetbrains.kotlinx.dataframe.typeClass
57+
import org.slf4j.Logger
58+
import org.slf4j.LoggerFactory
2359
import java.io.ByteArrayOutputStream
2460
import java.io.File
2561
import java.io.FileOutputStream
@@ -31,11 +67,13 @@ import java.time.LocalDateTime
3167
import java.time.LocalTime
3268
import kotlin.reflect.typeOf
3369

70+
/**
 * Default warning sink used when the caller does not supply its own subscriber:
 * prints each warning message to the standard error stream.
 */
private val writeWarningMessage: (String) -> Unit = { text -> System.err.println(text) }
71+
3472
/**
3573
* Create Arrow [Schema] matching [this] actual data.
3674
* Columns with unsupported types will be interpreted as String.
3775
*/
38-
public fun List<AnyCol>.toArrowSchema(): Schema {
76+
public fun List<AnyCol>.toArrowSchema(warningSubscriber: (String) -> Unit = writeWarningMessage): Schema {
3977
val fields = this.map { column ->
4078
when (column.type()) {
4179
typeOf<String?>() -> Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
@@ -71,7 +109,10 @@ public fun List<AnyCol>.toArrowSchema(): Schema {
71109
typeOf<LocalTime?>() -> Field(column.name(), FieldType(true, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
72110
typeOf<LocalTime>() -> Field(column.name(), FieldType(false, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
73111

74-
else -> Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
112+
else -> {
113+
warningSubscriber("Column ${column.name()} has type ${column.typeClass.java.canonicalName}, will be saved as String")
114+
Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
115+
}
75116
}
76117
}
77118
return Schema(fields)
@@ -85,13 +126,22 @@ public fun DataFrame<*>.arrowWriter(): ArrowWriter = this.arrowWriter(this.colum
85126
/**
86127
* Create [ArrowWriter] for [this] DataFrame with explicit [targetSchema]
87128
*/
88-
public fun DataFrame<*>.arrowWriter(targetSchema: Schema, mode: ArrowWriter.Companion.Mode = ArrowWriter.Companion.Mode.STRICT): ArrowWriter = ArrowWriter(this, targetSchema, mode)
129+
public fun DataFrame<*>.arrowWriter(
    targetSchema: Schema,
    mode: ArrowWriter.Companion.Mode = ArrowWriter.Companion.Mode.STRICT,
    warningSubscriber: (String) -> Unit = writeWarningMessage
): ArrowWriter {
    // Thin factory: hand the receiver frame and every option straight to the ArrowWriter constructor.
    return ArrowWriter(
        dataFrame = this,
        targetSchema = targetSchema,
        mode = mode,
        warningSubscriber = warningSubscriber
    )
}
89134

90135
/**
91136
* Save [dataFrame] content in Apache Arrow format (can be written to File, ByteArray, OutputStream or raw Channel) with [targetSchema].
92137
* If the [dataFrame] content does not match [targetSchema], the behaviour is determined by [mode].
93138
*/
94-
public class ArrowWriter(private val dataFrame: DataFrame<*>, private val targetSchema: Schema, private val mode: Mode): AutoCloseable {
139+
public class ArrowWriter(
140+
private val dataFrame: DataFrame<*>,
141+
private val targetSchema: Schema,
142+
private val mode: Mode,
143+
private val warningSubscriber: (String) -> Unit = writeWarningMessage
144+
): AutoCloseable {
95145

96146
public companion object {
97147
/**
@@ -143,15 +193,15 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
143193
ArrowType.Int(16, true) -> column.convertTo<Short>()
144194
ArrowType.Int(32, true) -> column.convertTo<Int>()
145195
ArrowType.Int(64, true) -> column.convertTo<Long>()
146-
// ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) ->
196+
// ArrowType.Int(8, false), ArrowType.Int(16, false), ArrowType.Int(32, false), ArrowType.Int(64, false) -> todo
147197
is ArrowType.Decimal -> column.convertToBigDecimal()
148198
ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) -> column.convertToFloat()
149199
ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) -> column.convertToDouble()
150200
ArrowType.Date(DateUnit.DAY) -> column.convertToLocalDate()
151201
ArrowType.Date(DateUnit.MILLISECOND) -> column.convertToLocalDateTime()
152202
is ArrowType.Time -> column.convertToLocalTime()
153-
// is ArrowType.Duration ->
154-
// is ArrowType.Struct ->
203+
// is ArrowType.Duration -> todo
204+
// is ArrowType.Struct -> todo
155205
else -> {
156206
TODO("Saving ${targetFieldType.javaClass.canonicalName} is not implemented")
157207
}
@@ -162,30 +212,30 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
162212
when (vector) {
163213
is VarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value} ?: vector.setNull(i) }
164214
is LargeVarCharVector -> column.convertToString().forEachIndexed { i, value -> value?.let { vector.set(i, Text(value)); value} ?: vector.setNull(i) }
165-
// is VarBinaryVector -> vector.values(range).withType()
166-
// is LargeVarBinaryVector -> vector.values(range).withType()
215+
// is VarBinaryVector -> todo
216+
// is LargeVarBinaryVector -> todo
167217
is BitVector -> column.convertToBoolean().forEachIndexed { i, value -> value?.let { vector.set(i, value.compareTo(false)); value} ?: vector.setNull(i) }
168218
is TinyIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
169219
is SmallIntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
170220
is IntVector -> column.convertToInt().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
171221
is BigIntVector -> column.convertToLong().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
172-
// is UInt1Vector -> vector.values(range).withType()
173-
// is UInt2Vector -> vector.values(range).withType()
174-
// is UInt4Vector -> vector.values(range).withType()
175-
// is UInt8Vector -> vector.values(range).withType()
222+
// is UInt1Vector -> todo
223+
// is UInt2Vector -> todo
224+
// is UInt4Vector -> todo
225+
// is UInt8Vector -> todo
176226
is DecimalVector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
177227
is Decimal256Vector -> column.convertToBigDecimal().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
178228
is Float8Vector -> column.convertToDouble().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
179229
is Float4Vector -> column.convertToFloat().forEachIndexed { i, value -> value?.let { vector.set(i, value); value} ?: vector.setNull(i) }
180230

181231
is DateDayVector -> column.convertToLocalDate().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toJavaLocalDate().toEpochDay()).toInt()); value} ?: vector.setNull(i) }
182232
is DateMilliVector -> column.convertToLocalDateTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toInstant(TimeZone.UTC).toEpochMilliseconds()); value} ?: vector.setNull(i) }
183-
// is DurationVector -> vector.values(range).withType()
233+
// is DurationVector -> todo
184234
is TimeNanoVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay()); value} ?: vector.setNull(i) }
185235
is TimeMicroVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, value.toNanoOfDay() / 1000); value} ?: vector.setNull(i) }
186236
is TimeMilliVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000).toInt()); value} ?: vector.setNull(i) }
187237
is TimeSecVector -> column.convertToLocalTime().forEachIndexed { i, value -> value?.let { vector.set(i, (value.toNanoOfDay() / 1000 / 1000 / 1000).toInt()); value} ?: vector.setNull(i) }
188-
// is StructVector -> vector.values(range).withType()
238+
// is StructVector -> todo
189239
else -> {
190240
TODO("Saving to ${vector.javaClass.canonicalName} is not implemented")
191241
}
@@ -206,7 +256,8 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
206256
throw e
207257
} else {
208258
// If strictType is not enabled, use original data with its type. Target nullable is saved at this step.
209-
val actualType = listOf(column!!).toArrowSchema().fields.first().fieldType.type
259+
warningSubscriber(e.message)
260+
val actualType = listOf(column!!).toArrowSchema(warningSubscriber).fields.first().fieldType.type
210261
val actualField = Field(field.name, FieldType(field.isNullable, actualType, field.fieldType.dictionary), field.children)
211262
column to actualField
212263
}
@@ -215,6 +266,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
215266
if (strictNullable) {
216267
throw Exception("${actualField.name} column contains nulls but should be not nullable")
217268
} else {
269+
warningSubscriber("${actualField.name} column contains nulls but expected not nullable")
218270
Field(actualField.name, FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
219271
}
220272
} else {
@@ -231,7 +283,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
231283
return vector
232284
}
233285

234-
private fun List<AnyCol>.toVectors(): List<FieldVector> = this.toArrowSchema().fields.mapIndexed { i, field ->
286+
/**
 * Builds one filled [FieldVector] per column of this list, using the schema that
 * [toArrowSchema] derives from the columns themselves. Warnings raised while the
 * schema is derived are forwarded to [warningSubscriber].
 */
private fun List<AnyCol>.toVectors(): List<FieldVector> {
    val derivedFields = this.toArrowSchema(warningSubscriber).fields
    // Fields come back in column order, so the field index doubles as the column index.
    return derivedFields.mapIndexed { index, derivedField ->
        allocateVectorAndInfill(derivedField, this[index], true, true)
    }
}
237289

@@ -246,6 +298,7 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
246298
if (mode.restrictNarrowing) {
247299
throw Exception("${field.name} column is not presented")
248300
} else {
301+
warningSubscriber("${field.name} column is not presented")
249302
continue
250303
}
251304
}
@@ -255,9 +308,11 @@ public class ArrowWriter(private val dataFrame: DataFrame<*>, private val target
255308
}
256309
val vectors = ArrayList<FieldVector>()
257310
vectors.addAll(mainVectors.values)
311+
val otherVectors = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }.toVectors()
258312
if (!mode.restrictWidening) {
259-
val otherVectors = dataFrame.columns().filter { column -> !mainVectors.containsKey(column.name()) }.toVectors()
260313
vectors.addAll(otherVectors)
314+
} else {
315+
otherVectors.forEach { warningSubscriber("${it.name} column is not described in target schema and was ignored") }
261316
}
262317
return VectorSchemaRoot(vectors)
263318
}

0 commit comments

Comments
 (0)