Skip to content

Commit ec8ee5c

Browse files
committed
Use isSubtypeOf in Arrow writing
1 parent f028ebe commit ec8ee5c

File tree

2 files changed

+48
-37
lines changed

2 files changed

+48
-37
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import java.nio.channels.WritableByteChannel
6565
import java.time.LocalDate
6666
import java.time.LocalDateTime
6767
import java.time.LocalTime
68+
import kotlin.reflect.full.isSubtypeOf
6869
import kotlin.reflect.typeOf
6970

7071
private val writeWarningMessage: (String) -> Unit = {message: String -> System.err.println(message)}
@@ -75,39 +76,30 @@ private val writeWarningMessage: (String) -> Unit = {message: String -> System.e
7576
*/
7677
public fun List<AnyCol>.toArrowSchema(warningSubscriber: (String) -> Unit = writeWarningMessage): Schema {
7778
val fields = this.map { column ->
78-
when (column.type()) {
79-
typeOf<String?>() -> Field(column.name(), FieldType(true, ArrowType.Utf8(), null), emptyList())
80-
typeOf<String>() -> Field(column.name(), FieldType(false, ArrowType.Utf8(), null), emptyList())
79+
val columnType = column.type()
80+
val nullable = columnType.isMarkedNullable
81+
when {
82+
columnType.isSubtypeOf(typeOf<String?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Utf8(), null), emptyList())
8183

82-
typeOf<Boolean?>() -> Field(column.name(), FieldType(true, ArrowType.Bool(), null), emptyList())
83-
typeOf<Boolean>() -> Field(column.name(), FieldType(false, ArrowType.Bool(), null), emptyList())
84+
columnType.isSubtypeOf(typeOf<Boolean?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Bool(), null), emptyList())
8485

85-
typeOf<Byte?>() -> Field(column.name(), FieldType(true, ArrowType.Int(8, true), null), emptyList())
86-
typeOf<Byte>() -> Field(column.name(), FieldType(false, ArrowType.Int(8, true), null), emptyList())
86+
columnType.isSubtypeOf(typeOf<Byte?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(8, true), null), emptyList())
8787

88-
typeOf<Short?>() -> Field(column.name(), FieldType(true, ArrowType.Int(16, true), null), emptyList())
89-
typeOf<Short>() -> Field(column.name(), FieldType(false, ArrowType.Int(16, true), null), emptyList())
88+
columnType.isSubtypeOf(typeOf<Short?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(16, true), null), emptyList())
9089

91-
typeOf<Int?>() -> Field(column.name(), FieldType(true, ArrowType.Int(32, true), null), emptyList())
92-
typeOf<Int>() -> Field(column.name(), FieldType(false, ArrowType.Int(32, true), null), emptyList())
90+
columnType.isSubtypeOf(typeOf<Int?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(32, true), null), emptyList())
9391

94-
typeOf<Long?>() -> Field(column.name(), FieldType(true, ArrowType.Int(64, true), null), emptyList())
95-
typeOf<Long>() -> Field(column.name(), FieldType(false, ArrowType.Int(64, true), null), emptyList())
92+
columnType.isSubtypeOf(typeOf<Long?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Int(64, true), null), emptyList())
9693

97-
typeOf<Float?>() -> Field(column.name(), FieldType(true, ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null), emptyList())
98-
typeOf<Float>() -> Field(column.name(), FieldType(false, ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null), emptyList())
94+
columnType.isSubtypeOf(typeOf<Float?>()) -> Field(column.name(), FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null), emptyList())
9995

100-
typeOf<Double?>() -> Field(column.name(), FieldType(true, ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null), emptyList())
101-
typeOf<Double>() -> Field(column.name(), FieldType(false, ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null), emptyList())
96+
columnType.isSubtypeOf(typeOf<Double?>()) -> Field(column.name(), FieldType(nullable, ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null), emptyList())
10297

103-
typeOf<LocalDate?>(), typeOf<kotlinx.datetime.LocalDate?>() -> Field(column.name(), FieldType(true, ArrowType.Date(DateUnit.DAY), null), emptyList())
104-
typeOf<LocalDate>(), typeOf<kotlinx.datetime.LocalDate>() -> Field(column.name(), FieldType(false, ArrowType.Date(DateUnit.DAY), null), emptyList())
98+
columnType.isSubtypeOf(typeOf<LocalDate?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDate?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Date(DateUnit.DAY), null), emptyList())
10599

106-
typeOf<LocalDateTime?>(), typeOf<kotlinx.datetime.LocalDateTime?>() -> Field(column.name(), FieldType(true, ArrowType.Date(DateUnit.MILLISECOND), null), emptyList())
107-
typeOf<LocalDateTime>(), typeOf<kotlinx.datetime.LocalDateTime>() -> Field(column.name(), FieldType(false, ArrowType.Date(DateUnit.MILLISECOND), null), emptyList())
100+
columnType.isSubtypeOf(typeOf<LocalDateTime?>()) || columnType.isSubtypeOf(typeOf<kotlinx.datetime.LocalDateTime?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Date(DateUnit.MILLISECOND), null), emptyList())
108101

109-
typeOf<LocalTime?>() -> Field(column.name(), FieldType(true, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
110-
typeOf<LocalTime>() -> Field(column.name(), FieldType(false, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
102+
columnType.isSubtypeOf(typeOf<LocalTime?>()) -> Field(column.name(), FieldType(nullable, ArrowType.Time(TimeUnit.NANOSECOND, 64), null), emptyList())
111103

112104
else -> {
113105
warningSubscriber("Column ${column.name()} has type ${column.typeClass.java.canonicalName}, will be saved as String")

dataframe-arrow/src/test/kotlin/ArrowKtTest.kt

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.junit.Test
1414
import java.io.File
1515
import java.net.URL
1616
import java.time.LocalDate
17+
import kotlin.reflect.typeOf
1718

1819
internal class ArrowKtTest {
1920

@@ -98,51 +99,69 @@ internal class ArrowKtTest {
9899
"Hamburg",
99100
"New York",
100101
"Washington",
101-
"Saint Petersburg"
102+
"Saint Petersburg",
103+
"Vatican"
102104
)),
103105
DataColumn.createValueColumn("affiliation", listOf(
104106
"Germany",
105107
"Germany",
106108
"The USA",
107109
"The USA",
108-
"Russia"
110+
"Russia",
111+
null
109112
)),
110113
DataColumn.createValueColumn("is_capital", listOf(
111114
true,
112115
false,
113116
false,
114117
true,
115-
false
118+
false,
119+
null
116120
)),
117121
DataColumn.createValueColumn("population", listOf(
118122
3_769_495,
119123
1_845_229,
120124
8_467_513,
121125
689_545,
122-
5_377_503
126+
5_377_503,
127+
825
123128
)),
124129
DataColumn.createValueColumn("area", listOf(
125130
891.7,
126131
755.22,
127132
1223.59,
128133
177.0,
129-
1439.0
134+
1439.0,
135+
0.44
130136
)),
131-
// DataColumn.createValueColumn("settled", listOf(
132-
// LocalDate.of(1237, 1, 1),
133-
// LocalDate.of(1189, 5, 7),
134-
// LocalDate.of(1624, 1, 1),
135-
// LocalDate.of(1790, 7, 16),
136-
// LocalDate.of(1703, 5, 27)
137-
// ))
137+
DataColumn.createValueColumn("settled", listOf(
138+
LocalDate.of(1237, 1, 1),
139+
LocalDate.of(1189, 5, 7),
140+
LocalDate.of(1624, 1, 1),
141+
LocalDate.of(1790, 7, 16),
142+
LocalDate.of(1703, 5, 27),
143+
LocalDate.of(1929, 2, 11)
144+
))
138145
)
139146

140147
@Test
141148
fun testWritingGeneral() {
149+
fun assertEstimation(citiesDeserialized: DataFrame<*>) {
150+
citiesDeserialized["name"] shouldBe cities["name"]
151+
citiesDeserialized["affiliation"] shouldBe cities["affiliation"]
152+
citiesDeserialized["is_capital"] shouldBe cities["is_capital"]
153+
citiesDeserialized["population"] shouldBe cities["population"]
154+
citiesDeserialized["area"] shouldBe cities["area"]
155+
citiesDeserialized["settled"].type() shouldBe typeOf<LocalDate>() // cities["settled"].type() refers to FlexibleTypeImpl(LocalDate..LocalDate?) and does not match typeOf<LocalDate>()
156+
citiesDeserialized["settled"].values() shouldBe cities["settled"].values()
157+
}
158+
142159
val testFile = File.createTempFile("cities", "arrow")
143160
cities.arrowWriter().writeArrowFeather(testFile)
144-
DataFrame.readArrowFeather(testFile).shouldBe(cities)
145-
}
161+
assertEstimation(DataFrame.readArrowFeather(testFile))
146162

163+
val testByteArray = cities.arrowWriter().saveArrowIPCToByteArray()
164+
assertEstimation(DataFrame.readArrowIPC(testByteArray))
165+
}
147166

148167
}

0 commit comments

Comments
 (0)