Skip to content

Commit 8ccd85b

Browse files
committed
Nullable checking beta
1 parent 49c1cbb commit 8ccd85b

File tree

10 files changed

+128
-74
lines changed

10 files changed

+128
-74
lines changed

dataframe-arrow/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies {
1212
implementation(libs.arrow.format)
1313
implementation(libs.arrow.memory)
1414
implementation(libs.commonsCompress)
15+
implementation(libs.kotlin.reflect)
1516

1617
testApi(project(":core"))
1718
testImplementation(libs.junit)

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ import java.time.Duration
5858
import java.time.LocalDate
5959
import java.time.LocalDateTime
6060
import java.time.LocalTime
61+
import kotlin.reflect.KType
62+
import kotlin.reflect.full.withNullability
6163
import kotlin.reflect.typeOf
6264

6365
public class ArrowFeather : SupportedFormat {
@@ -141,7 +143,7 @@ public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, al
141143
private fun BitVector.values(range: IntRange): List<Boolean?> = range.map { getObject(it) }
142144

143145
private fun UInt1Vector.values(range: IntRange): List<Short?> = range.map { getObjectNoOverflow(it) }
144-
private fun UInt2Vector.values(range: IntRange): List<Int?> = range.map { getObject(it).code }
146+
private fun UInt2Vector.values(range: IntRange): List<Int?> = range.map { getObject(it)?.code }
145147
private fun UInt4Vector.values(range: IntRange): List<Long?> = range.map { getObjectNoOverflow(it) }
146148
private fun UInt8Vector.values(range: IntRange): List<BigInteger?> = range.map { getObjectNoOverflow(it) }
147149

@@ -158,14 +160,33 @@ private fun Float8Vector.values(range: IntRange): List<Double?> = range.map { ge
158160

159161
private fun DurationVector.values(range: IntRange): List<Duration?> = range.map { getObject(it) }
160162
private fun DateDayVector.values(range: IntRange): List<LocalDate?> = range.map {
163+
if (getObject(it) == null) null else
161164
DateUtility.getLocalDateTimeFromEpochMilli(getObject(it).toLong() * DateUtility.daysToStandardMillis).toLocalDate()
162165
}
163166
private fun DateMilliVector.values(range: IntRange): List<LocalDateTime?> = range.map { getObject(it) }
164167

165-
private fun TimeNanoVector.values(range: IntRange): List<LocalTime?> = range.map { LocalTime.ofNanoOfDay(get(it)) }
166-
private fun TimeMicroVector.values(range: IntRange): List<LocalTime?> = range.map { LocalTime.ofNanoOfDay(get(it) * 1000) }
167-
private fun TimeMilliVector.values(range: IntRange): List<LocalTime?> = range.map { LocalTime.ofNanoOfDay(get(it).toLong() * 1000_000) }
168-
private fun TimeSecVector.values(range: IntRange): List<LocalTime?> = range.map { LocalTime.ofSecondOfDay(get(it).toLong()) }
168+
private fun TimeNanoVector.values(range: IntRange): List<LocalTime?> = range.mapIndexed { i, it ->
169+
if (isNull(i)) {
170+
null
171+
} else {
172+
LocalTime.ofNanoOfDay(get(it))
173+
}
174+
}
175+
private fun TimeMicroVector.values(range: IntRange): List<LocalTime?> = range.mapIndexed { i, it ->
176+
if (isNull(i)) {
177+
null
178+
} else {
179+
LocalTime.ofNanoOfDay(getObject(it) * 1000)
180+
}
181+
}
182+
private fun TimeMilliVector.values(range: IntRange): List<LocalTime?> = range.mapIndexed { i, it ->
183+
if (isNull(i)) {
184+
null
185+
} else {
186+
LocalTime.ofNanoOfDay(get(it).toLong() * 1000_000)
187+
}
188+
}
189+
private fun TimeSecVector.values(range: IntRange): List<LocalTime?> = range.map { getObject(it)?.let {LocalTime.ofSecondOfDay(it.toLong())} }
169190

170191
private fun StructVector.values(range: IntRange): List<Map<String, Any?>?> = range.map { getObject(it) }
171192

@@ -201,36 +222,36 @@ private fun LargeVarCharVector.values(range: IntRange): List<String?> = range.ma
201222
}
202223
}
203224

204-
private inline fun <reified T> List<T>.withType() = this to typeOf<T>()
225+
private inline fun <reified T> List<T>.withType(nullability: Boolean) = this to typeOf<T>().withNullability(nullability)
205226

206227
private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol {
207228
val range = 0 until root.rowCount
208229
val (list, type) = when (val vector = root.getVector(field)) {
209-
is VarCharVector -> vector.values(range).withType()
210-
is LargeVarCharVector -> vector.values(range).withType()
211-
is VarBinaryVector -> vector.values(range).withType()
212-
is LargeVarBinaryVector -> vector.values(range).withType()
213-
is BitVector -> vector.values(range).withType()
214-
is SmallIntVector -> vector.values(range).withType()
215-
is TinyIntVector -> vector.values(range).withType()
216-
is UInt1Vector -> vector.values(range).withType()
217-
is UInt2Vector -> vector.values(range).withType()
218-
is UInt4Vector -> vector.values(range).withType()
219-
is UInt8Vector -> vector.values(range).withType()
220-
is IntVector -> vector.values(range).withType()
221-
is BigIntVector -> vector.values(range).withType()
222-
is DecimalVector -> vector.values(range).withType()
223-
is Decimal256Vector -> vector.values(range).withType()
224-
is Float8Vector -> vector.values(range).withType()
225-
is Float4Vector -> vector.values(range).withType()
226-
is DurationVector -> vector.values(range).withType()
227-
is DateDayVector -> vector.values(range).withType()
228-
is DateMilliVector -> vector.values(range).withType()
229-
is TimeNanoVector -> vector.values(range).withType()
230-
is TimeMicroVector -> vector.values(range).withType()
231-
is TimeMilliVector -> vector.values(range).withType()
232-
is TimeSecVector -> vector.values(range).withType()
233-
is StructVector -> vector.values(range).withType()
230+
is VarCharVector -> vector.values(range).withType(field.isNullable)
231+
is LargeVarCharVector -> vector.values(range).withType(field.isNullable)
232+
is VarBinaryVector -> vector.values(range).withType(field.isNullable)
233+
is LargeVarBinaryVector -> vector.values(range).withType(field.isNullable)
234+
is BitVector -> vector.values(range).withType(field.isNullable)
235+
is SmallIntVector -> vector.values(range).withType(field.isNullable)
236+
is TinyIntVector -> vector.values(range).withType(field.isNullable)
237+
is UInt1Vector -> vector.values(range).withType(field.isNullable)
238+
is UInt2Vector -> vector.values(range).withType(field.isNullable)
239+
is UInt4Vector -> vector.values(range).withType(field.isNullable)
240+
is UInt8Vector -> vector.values(range).withType(field.isNullable)
241+
is IntVector -> vector.values(range).withType(field.isNullable)
242+
is BigIntVector -> vector.values(range).withType(field.isNullable)
243+
is DecimalVector -> vector.values(range).withType(field.isNullable)
244+
is Decimal256Vector -> vector.values(range).withType(field.isNullable)
245+
is Float8Vector -> vector.values(range).withType(field.isNullable)
246+
is Float4Vector -> vector.values(range).withType(field.isNullable)
247+
is DurationVector -> vector.values(range).withType(field.isNullable)
248+
is DateDayVector -> vector.values(range).withType(field.isNullable)
249+
is DateMilliVector -> vector.values(range).withType(field.isNullable)
250+
is TimeNanoVector -> vector.values(range).withType(field.isNullable)
251+
is TimeMicroVector -> vector.values(range).withType(field.isNullable)
252+
is TimeMilliVector -> vector.values(range).withType(field.isNullable)
253+
is TimeSecVector -> vector.values(range).withType(field.isNullable)
254+
is StructVector -> vector.values(range).withType(field.isNullable)
234255
else -> {
235256
TODO("not fully implemented")
236257
}

dataframe-arrow/src/test/kotlin/ArrowKtTest.kt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,25 @@ internal class ArrowKtTest {
3838

3939
@Test
4040
fun testReadingAllTypesAsEstimated() {
41-
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow")))
42-
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow")))
41+
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow")), true, false)
42+
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow")), true, false)
43+
}
44+
45+
@Test
46+
fun testReadingAllTypesAsEstimatedWithNulls() {
47+
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-with-nulls.arrow")), true, true)
48+
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-with-nulls.arrow")), true, true)
49+
}
50+
51+
@Test
52+
fun testReadingAllTypesAsEstimatedNotNullable() {
53+
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-not-nullable.arrow")), false, false)
54+
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-not-nullable.arrow")), false, false)
55+
}
56+
57+
@Test
58+
fun testReadingAllTypesAsEstimatedNotNullableWithNulls() {
59+
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-illegal.arrow")), false, true)
60+
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-illegal.arrow")), false, true)
4361
}
4462
}

dataframe-arrow/src/test/kotlin/exampleEstimatesAssertions.kt

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,137 +9,151 @@ import java.time.LocalTime
99
import java.time.ZoneOffset
1010
import kotlin.math.absoluteValue
1111
import kotlin.math.pow
12+
import kotlin.reflect.full.withNullability
1213
import kotlin.reflect.typeOf
1314

1415
/**
1516
* Assert that we have got the same data that was originally saved on example creation.
1617
*/
17-
internal fun assertEstimations(exampleFrame: AnyFrame) {
18+
internal fun assertEstimations(exampleFrame: AnyFrame, nullable: Boolean, withNulls: Boolean) {
1819
/**
1920
* In [exampleFrame] we get two concatenated batches. To assert the estimations, we should transform frame row number to batch row number
2021
*/
2122
fun iBatch(iFrame: Int): Int {
2223
val firstBatchSize = 100;
2324
return if (iFrame < firstBatchSize) iFrame else iFrame - firstBatchSize
2425
}
26+
27+
fun expectedNull(rowNumber: Int): Boolean {
28+
return (rowNumber + 1) % 5 == 0;
29+
}
30+
31+
fun assertValueOrNull(rowNumber: Int, actual: Any?, expected: Any) {
32+
if (withNulls && expectedNull(rowNumber)) {
33+
actual shouldBe null
34+
} else {
35+
actual shouldBe expected
36+
}
37+
}
38+
2539
val asciiStringCol = exampleFrame["asciiString"] as DataColumn<String?>
26-
asciiStringCol.type() shouldBe typeOf<String?>()
40+
asciiStringCol.type() shouldBe typeOf<String>().withNullability(nullable)
2741
asciiStringCol.forEachIndexed { i, element ->
28-
element shouldBe "Test Example ${iBatch(i)}"
42+
assertValueOrNull(iBatch(i), element, "Test Example ${iBatch(i)}")
2943
}
3044

3145
val utf8StringCol = exampleFrame["utf8String"] as DataColumn<String?>
32-
utf8StringCol.type() shouldBe typeOf<String?>()
46+
utf8StringCol.type() shouldBe typeOf<String>().withNullability(nullable)
3347
utf8StringCol.forEachIndexed { i, element ->
34-
element shouldBe "Тестовый пример ${iBatch(i)}"
48+
assertValueOrNull(iBatch(i), element, "Тестовый пример ${iBatch(i)}")
3549
}
3650

3751
val largeStringCol = exampleFrame["largeString"] as DataColumn<String?>
38-
largeStringCol.type() shouldBe typeOf<String?>()
52+
largeStringCol.type() shouldBe typeOf<String>().withNullability(nullable)
3953
largeStringCol.forEachIndexed { i, element ->
40-
element shouldBe "Test Example Should Be Large ${iBatch(i)}"
54+
assertValueOrNull(iBatch(i), element, "Test Example Should Be Large ${iBatch(i)}")
4155
}
4256

4357
val booleanCol = exampleFrame["boolean"] as DataColumn<Boolean?>
44-
booleanCol.type() shouldBe typeOf<Boolean?>()
58+
booleanCol.type() shouldBe typeOf<Boolean>().withNullability(nullable)
4559
booleanCol.forEachIndexed { i, element ->
46-
element shouldBe (iBatch(i) % 2 == 0)
60+
assertValueOrNull(iBatch(i), element, iBatch(i) % 2 == 0)
4761
}
4862

4963
val byteCol = exampleFrame["byte"] as DataColumn<Byte?>
50-
byteCol.type() shouldBe typeOf<Byte?>()
64+
byteCol.type() shouldBe typeOf<Byte>().withNullability(nullable)
5165
byteCol.forEachIndexed { i, element ->
52-
element shouldBe (iBatch(i) * 10).toByte()
66+
assertValueOrNull(iBatch(i), element, (iBatch(i) * 10).toByte())
5367
}
5468

5569
val shortCol = exampleFrame["short"] as DataColumn<Short?>
56-
shortCol.type() shouldBe typeOf<Short?>()
70+
shortCol.type() shouldBe typeOf<Short>().withNullability(nullable)
5771
shortCol.forEachIndexed { i, element ->
58-
element shouldBe (iBatch(i) * 1000).toShort()
72+
assertValueOrNull(iBatch(i), element, (iBatch(i) * 1000).toShort())
5973
}
6074

6175
val intCol = exampleFrame["int"] as DataColumn<Int?>
62-
intCol.type() shouldBe typeOf<Int?>()
76+
intCol.type() shouldBe typeOf<Int>().withNullability(nullable)
6377
intCol.forEachIndexed { i, element ->
64-
element shouldBe (iBatch(i) * 100000000).toInt()
78+
assertValueOrNull(iBatch(i), element, iBatch(i) * 100000000)
6579
}
6680

6781
val longCol = exampleFrame["longInt"] as DataColumn<Long?>
68-
longCol.type() shouldBe typeOf<Long?>()
82+
longCol.type() shouldBe typeOf<Long>().withNullability(nullable)
6983
longCol.forEachIndexed { i, element ->
70-
element shouldBe iBatch(i) * 100000000000000000L
84+
assertValueOrNull(iBatch(i), element, iBatch(i) * 100000000000000000L)
7185
}
7286

7387
val unsignedByteCol = exampleFrame["unsigned_byte"] as DataColumn<Short?>
74-
unsignedByteCol.type() shouldBe typeOf<Short?>()
88+
unsignedByteCol.type() shouldBe typeOf<Short>().withNullability(nullable)
7589
unsignedByteCol.forEachIndexed { i, element ->
76-
element shouldBe (iBatch(i) * 10 % (Byte.MIN_VALUE.toShort() * 2).absoluteValue).toShort()
90+
assertValueOrNull(iBatch(i), element, (iBatch(i) * 10 % (Byte.MIN_VALUE.toShort() * 2).absoluteValue).toShort())
7791
}
7892

7993
val unsignedShortCol = exampleFrame["unsigned_short"] as DataColumn<Int?>
80-
unsignedShortCol.type() shouldBe typeOf<Int?>()
94+
unsignedShortCol.type() shouldBe typeOf<Int>().withNullability(nullable)
8195
unsignedShortCol.forEachIndexed { i, element ->
82-
element shouldBe (iBatch(i) * 1000 % (Short.MIN_VALUE.toInt() * 2).absoluteValue)
96+
assertValueOrNull(iBatch(i), element, iBatch(i) * 1000 % (Short.MIN_VALUE.toInt() * 2).absoluteValue)
8397
}
8498

8599
val unsignedIntCol = exampleFrame["unsigned_int"] as DataColumn<Long?>
86-
unsignedIntCol.type() shouldBe typeOf<Long?>()
100+
unsignedIntCol.type() shouldBe typeOf<Long>().withNullability(nullable)
87101
unsignedIntCol.forEachIndexed { i, element ->
88-
element shouldBe (iBatch(i).toLong() * 100000000 % (Int.MIN_VALUE.toLong() * 2).absoluteValue)
102+
assertValueOrNull(iBatch(i), element, iBatch(i).toLong() * 100000000 % (Int.MIN_VALUE.toLong() * 2).absoluteValue)
89103
}
90104

91105
val unsignedLongIntCol = exampleFrame["unsigned_longInt"] as DataColumn<BigInteger?>
92-
unsignedLongIntCol.type() shouldBe typeOf<BigInteger?>()
106+
unsignedLongIntCol.type() shouldBe typeOf<BigInteger>().withNullability(nullable)
93107
unsignedLongIntCol.forEachIndexed { i, element ->
94-
element shouldBe (iBatch(i).toBigInteger() * 100000000000000000L.toBigInteger() % (Long.MIN_VALUE.toBigInteger() * 2.toBigInteger()).abs())
108+
assertValueOrNull(iBatch(i), element, iBatch(i).toBigInteger() * 100000000000000000L.toBigInteger() % (Long.MIN_VALUE.toBigInteger() * 2.toBigInteger()).abs())
95109
}
96110

97111
val floatCol = exampleFrame["float"] as DataColumn<Float?>
98-
floatCol.type() shouldBe typeOf<Float?>()
112+
floatCol.type() shouldBe typeOf<Float>().withNullability(nullable)
99113
floatCol.forEachIndexed { i, element ->
100-
element shouldBe (2.0f.pow(iBatch(i).toFloat()))
114+
assertValueOrNull(iBatch(i), element, 2.0f.pow(iBatch(i).toFloat()))
101115
}
102116

103117
val doubleCol = exampleFrame["double"] as DataColumn<Double?>
104-
doubleCol.type() shouldBe typeOf<Double?>()
118+
doubleCol.type() shouldBe typeOf<Double>().withNullability(nullable)
105119
doubleCol.forEachIndexed { i, element ->
106-
element shouldBe (2.0.pow(iBatch(i).toDouble()))
120+
assertValueOrNull(iBatch(i), element, 2.0.pow(iBatch(i)))
107121
}
108122

109123
val dateCol = exampleFrame["date32"] as DataColumn<LocalDate?>
110-
dateCol.type() shouldBe typeOf<LocalDate?>()
124+
dateCol.type() shouldBe typeOf<LocalDate>().withNullability(nullable)
111125
dateCol.forEachIndexed { i, element ->
112-
element shouldBe LocalDate.ofEpochDay(iBatch(i).toLong() * 30)
126+
assertValueOrNull(iBatch(i), element, LocalDate.ofEpochDay(iBatch(i).toLong() * 30))
113127
}
114128

115129
val datetimeCol = exampleFrame["date64"] as DataColumn<LocalDateTime?>
116-
datetimeCol.type() shouldBe typeOf<LocalDateTime?>()
130+
datetimeCol.type() shouldBe typeOf<LocalDateTime>().withNullability(nullable)
117131
datetimeCol.forEachIndexed { i, element ->
118-
element shouldBe LocalDateTime.ofEpochSecond(iBatch(i).toLong() * 60 * 60 * 24 * 30, 0, ZoneOffset.UTC)
132+
assertValueOrNull(iBatch(i), element, LocalDateTime.ofEpochSecond(iBatch(i).toLong() * 60 * 60 * 24 * 30, 0, ZoneOffset.UTC))
119133
}
120134

121135
val timeSecCol = exampleFrame["time32_seconds"] as DataColumn<LocalTime?>
122-
timeSecCol.type() shouldBe typeOf<LocalTime?>()
136+
timeSecCol.type() shouldBe typeOf<LocalTime>().withNullability(nullable)
123137
timeSecCol.forEachIndexed { i, element ->
124-
element shouldBe LocalTime.ofSecondOfDay(iBatch(i).toLong())
138+
assertValueOrNull(iBatch(i), element, LocalTime.ofSecondOfDay(iBatch(i).toLong()))
125139
}
126140

127141
val timeMilliCol = exampleFrame["time32_milli"] as DataColumn<LocalTime?>
128-
timeMilliCol.type() shouldBe typeOf<LocalTime?>()
142+
timeMilliCol.type() shouldBe typeOf<LocalTime>().withNullability(nullable)
129143
timeMilliCol.forEachIndexed { i, element ->
130-
element shouldBe LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000_000)
144+
assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000_000))
131145
}
132146

133147
val timeMicroCol = exampleFrame["time64_micro"] as DataColumn<LocalTime?>
134-
timeMicroCol.type() shouldBe typeOf<LocalTime?>()
148+
timeMicroCol.type() shouldBe typeOf<LocalTime>().withNullability(nullable)
135149
timeMicroCol.forEachIndexed { i, element ->
136-
element shouldBe LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000)
150+
assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000))
137151
}
138152

139153
val timeNanoCol = exampleFrame["time64_nano"] as DataColumn<LocalTime?>
140-
timeNanoCol.type() shouldBe typeOf<LocalTime?>()
154+
timeNanoCol.type() shouldBe typeOf<LocalTime>().withNullability(nullable)
141155
timeNanoCol.forEachIndexed { i, element ->
142-
element shouldBe LocalTime.ofNanoOfDay(iBatch(i).toLong())
156+
assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong()))
143157
}
144158

145159
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)