Skip to content

Commit 9e2a1ab

Browse files
committed
Extract AnyCol.toArrowField function
1 parent 6c439fa commit 9e2a1ab

File tree

3 files changed

+89
-35
lines changed

3 files changed

+89
-35
lines changed

dataframe-arrow/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,7 @@ kotlinPublications {
3535
kotlin {
3636
explicitApi()
3737
}
38+
39+
tasks.test {
40+
jvmArgs = listOf("--add-opens", "java.base/java.nio=ALL-UNNAMED")
41+
}

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -77,43 +77,89 @@ private val logger = LoggerFactory.getLogger(ArrowWriter::class.java)
7777

7878
public val logWarningMessage: (String) -> Unit = { message: String -> logger.debug(message) }
7979

80+
public fun AnyCol.toArrowField(warningSubscriber: (String) -> Unit = ignoreWarningMessage): Field {
81+
val column = this
82+
val columnType = column.type()
83+
val nullable = columnType.isMarkedNullable
84+
return when {
85+
columnType.isSubtypeOf(typeOf<String?>()) -> Field(
86+
column.name(),
87+
FieldType(nullable, ArrowType.Utf8(), null),
88+
emptyList()
89+
)
90+
91+
columnType.isSubtypeOf(typeOf<Boolean?>()) -> Field(
92+
column.name(),
93+
FieldType(nullable, ArrowType.Bool(), null),
94+
emptyList()
95+
)
96+
97+
columnType.isSubtypeOf(typeOf<Byte?>()) -> Field(
98+
column.name(),
99+
FieldType(nullable, ArrowType.Int(8, true), null),
100+
emptyList()
101+
)
102+
103+
columnType.isSubtypeOf(typeOf<Short?>()) -> Field(
104+
column.name(),
105+
FieldType(nullable, ArrowType.Int(16, true), null),
106+
emptyList()
107+
)
108+
109+
columnType.isSubtypeOf(typeOf<Int?>()) -> Field(
110+
column.name(),
111+
FieldType(nullable, ArrowType.Int(32, true), null),
112+
emptyList()
113+
)
114+
115+
columnType.isSubtypeOf(typeOf<Long?>()) -> Field(
116+
column.name(),
117+
FieldType(nullable, ArrowType.Int(64, true), null),
118+
emptyList()
119+
)
120+
121+
columnType.isSubtypeOf(typeOf<Float?>()) -> Field(
122+
column.name(),
123+
FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null),
124+
emptyList()
125+
)
126+
127+
columnType.isSubtypeOf(typeOf<Double?>()) -> Field(
128+
column.name(),
129+
FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null),
130+
emptyList()
131+
)
132+
133+
columnType.isSubtypeOf(typeOf<LocalDate?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDate?>()) -> Field(
134+
column.name(),
135+
FieldType(nullable, ArrowType.Date(DateUnit.DAY), null),
136+
emptyList()
137+
)
138+
139+
columnType.isSubtypeOf(typeOf<LocalDateTime?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDateTime?>()) -> Field(
140+
column.name(),
141+
FieldType(nullable, ArrowType.Date(DateUnit.MILLISECOND), null),
142+
emptyList()
143+
)
144+
145+
columnType.isSubtypeOf(typeOf<LocalTime?>()) -> Field(
146+
column.name(),
147+
FieldType(nullable, ArrowType.Time(TimeUnit.NANOSECOND, 64), null),
148+
emptyList()
149+
)
150+
151+
else -> {
152+
warningSubscriber("Column ${column.name()} has type ${column.typeClass.java.canonicalName}, will be saved as String")
153+
Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
154+
}
155+
}
156+
}
80157
/**
81158
* Create Arrow [Schema] matching [this] actual data.
82159
* Columns with not supported types will be interpreted as String
83160
*/
84161
public fun List<AnyCol>.toArrowSchema(warningSubscriber: (String) -> Unit = ignoreWarningMessage): Schema {
85-
val fields = this.map { column ->
86-
val columnType = column.type()
87-
val nullable = columnType.isMarkedNullable
88-
when {
89-
columnType.isSubtypeOf(typeOf<String?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Utf8(), null), emptyList())
90-
91-
columnType.isSubtypeOf(typeOf<Boolean?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Bool(), null), emptyList())
92-
93-
columnType.isSubtypeOf(typeOf<Byte?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(8, true), null), emptyList())
94-
95-
columnType.isSubtypeOf(typeOf<Short?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(16, true), null), emptyList())
96-
97-
columnType.isSubtypeOf(typeOf<Int?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(32, true), null), emptyList())
98-
99-
columnType.isSubtypeOf(typeOf<Long?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(64, true), null), emptyList())
100-
101-
columnType.isSubtypeOf(typeOf<Float?>()) -> Field(column.name(), FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null), emptyList())
102-
103-
columnType.isSubtypeOf(typeOf<Double?>()) -> Field(column.name(), FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null), emptyList())
104-
105-
columnType.isSubtypeOf(typeOf<LocalDate?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDate?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Date(DateUnit.DAY), null), emptyList())
106-
107-
columnType.isSubtypeOf(typeOf<LocalDateTime?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDateTime?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Date(DateUnit.MILLISECOND), null), emptyList())
108-
109-
columnType.isSubtypeOf(typeOf<LocalTime?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
110-
111-
else -> {
112-
warningSubscriber("Column ${column.name()} has type ${column.typeClass.java.canonicalName}, will be saved as String")
113-
Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
114-
}
115-
}
116-
}
162+
val fields = this.map { it.toArrowField(warningSubscriber) }
117163
return Schema(fields)
118164
}
119165

@@ -292,10 +338,10 @@ public class ArrowWriter(
292338
return vector
293339
}
294340

295-
private fun List<AnyCol>.toVectors(): List<FieldVector> = this.toArrowSchema(warningSubscriber).fields.mapIndexed { i, field ->
296-
allocateVectorAndInfill(field, this[i], true, true)
341+
private fun List<AnyCol>.toVectors(): List<FieldVector> = this.map {
342+
val field = it.toArrowField(warningSubscriber)
343+
allocateVectorAndInfill(field, it, true, true)
297344
}
298-
299345
/**
300346
* Create Arrow VectorSchemaRoot with [dataFrame] content cast to [targetSchema] according to the [mode].
301347
*/

tests/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,7 @@ kotlinter {
100100
"filename"
101101
)
102102
}
103+
104+
tasks.test {
105+
jvmArgs = listOf("--add-opens", "java.base/java.nio=ALL-UNNAMED")
106+
}

0 commit comments

Comments
 (0)