Skip to content

Commit f028ebe

Browse files
committed
Save to Arrow test
1 parent 68bf499 commit f028ebe

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowWriter.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ public class ArrowWriter(
245245

246246
}
247247

248+
/**
249+
* Creates an Arrow [FieldVector] filled with the content of [column] cast to the type of [field], honoring the [strictType] and [strictNullable] settings.
250+
*/
248251
private fun allocateVectorAndInfill(field: Field, column: AnyCol?, strictType: Boolean, strictNullable: Boolean): FieldVector {
249252
val containNulls = (column == null || column.hasNulls())
250253
// Convert the column to type specified in field. (If we already have target type, convertTo will do nothing)
@@ -264,7 +267,7 @@ public class ArrowWriter(
264267
}
265268
val vector = if (!actualField.isNullable && containNulls) {
266269
if (strictNullable) {
267-
throw Exception("${actualField.name} column contains nulls but should be not nullable")
270+
throw IllegalArgumentException("${actualField.name} column contains nulls but should be not nullable")
268271
} else {
269272
warningSubscriber("${actualField.name} column contains nulls but expected not nullable")
270273
Field(actualField.name, FieldType(true, actualField.fieldType.type, actualField.fieldType.dictionary), actualField.children).createVector(allocator)!!
@@ -296,7 +299,7 @@ public class ArrowWriter(
296299
val column = dataFrame.getColumnOrNull(field.name)
297300
if (column == null && !field.isNullable) {
298301
if (mode.restrictNarrowing) {
299-
throw Exception("${field.name} column is not presented")
302+
throw IllegalArgumentException("${field.name} column is not presented")
300303
} else {
301304
warningSubscriber("${field.name} column is not presented")
302305
continue

dataframe-arrow/src/test/kotlin/ArrowKtTest.kt

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import io.kotest.assertions.throwables.shouldThrow
22
import io.kotest.matchers.shouldBe
33
import org.apache.arrow.vector.util.Text
4+
import org.jetbrains.kotlinx.dataframe.DataColumn
45
import org.jetbrains.kotlinx.dataframe.DataFrame
56
import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
67
import org.jetbrains.kotlinx.dataframe.api.columnOf
78
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
89
import org.jetbrains.kotlinx.dataframe.api.toColumn
10+
import org.jetbrains.kotlinx.dataframe.io.arrowWriter
911
import org.jetbrains.kotlinx.dataframe.io.readArrowFeather
1012
import org.jetbrains.kotlinx.dataframe.io.readArrowIPC
1113
import org.junit.Test
14+
import java.io.File
1215
import java.net.URL
16+
import java.time.LocalDate
1317

1418
internal class ArrowKtTest {
1519

@@ -87,4 +91,58 @@ internal class ArrowKtTest {
8791
assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-illegal.arrow"), NullabilityOptions.Widening), true, true)
8892
assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-illegal.arrow"), NullabilityOptions.Widening), true, true)
8993
}
94+
95+
/**
 * Fixture data frame with five rows describing cities.
 *
 * Columns, in order: `name` and `affiliation` (strings), `is_capital` (booleans),
 * `population` (integers) and `area` (doubles). A sixth `settled` date column is
 * currently commented out and is not part of the frame.
 */
val cities = dataFrameOf(
    DataColumn.createValueColumn(
        "name",
        listOf("Berlin", "Hamburg", "New York", "Washington", "Saint Petersburg")
    ),
    DataColumn.createValueColumn(
        "affiliation",
        listOf("Germany", "Germany", "The USA", "The USA", "Russia")
    ),
    DataColumn.createValueColumn(
        "is_capital",
        listOf(true, false, false, true, false)
    ),
    DataColumn.createValueColumn(
        "population",
        listOf(3_769_495, 1_845_229, 8_467_513, 689_545, 5_377_503)
    ),
    DataColumn.createValueColumn(
        "area",
        listOf(891.7, 755.22, 1223.59, 177.0, 1439.0)
    ),
    // Disabled "settled" column, kept for reference:
    // DataColumn.createValueColumn("settled", listOf(
    //     LocalDate.of(1237, 1, 1),
    //     LocalDate.of(1189, 5, 7),
    //     LocalDate.of(1624, 1, 1),
    //     LocalDate.of(1790, 7, 16),
    //     LocalDate.of(1703, 5, 27)
    // ))
)
139+
140+
/**
 * Round-trip check: writes [cities] to a temporary Arrow Feather file and
 * verifies that reading the file back yields an equal data frame.
 */
@Test
fun testWritingGeneral() {
    // createTempFile appends the suffix verbatim, so the leading dot is required
    // to get a "*.arrow" file name instead of "*arrow".
    val testFile = File.createTempFile("cities", ".arrow")
    try {
        cities.arrowWriter().writeArrowFeather(testFile)
        DataFrame.readArrowFeather(testFile).shouldBe(cities)
    } finally {
        // Best-effort cleanup so repeated runs do not pile up files in the temp directory.
        testFile.delete()
    }
}
146+
147+
90148
}

0 commit comments

Comments
 (0)