|
1 | 1 | import io.kotest.assertions.throwables.shouldThrow
|
| 2 | +import io.kotest.matchers.collections.shouldContain |
2 | 3 | import io.kotest.matchers.shouldBe
|
3 | 4 | import org.apache.arrow.vector.types.pojo.Schema
|
4 | 5 | import org.apache.arrow.vector.util.Text
|
5 | 6 | import org.jetbrains.kotlinx.dataframe.DataColumn
|
6 | 7 | import org.jetbrains.kotlinx.dataframe.DataFrame
|
7 | 8 | import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
|
| 9 | +import org.jetbrains.kotlinx.dataframe.api.add |
8 | 10 | import org.jetbrains.kotlinx.dataframe.api.columnOf
|
| 11 | +import org.jetbrains.kotlinx.dataframe.api.convertToBoolean |
| 12 | +import org.jetbrains.kotlinx.dataframe.api.copy |
9 | 13 | import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
|
| 14 | +import org.jetbrains.kotlinx.dataframe.api.map |
10 | 15 | import org.jetbrains.kotlinx.dataframe.api.toColumn
|
11 |
| -import org.jetbrains.kotlinx.dataframe.io.* |
| 16 | +import org.jetbrains.kotlinx.dataframe.api.remove |
| 17 | +import org.jetbrains.kotlinx.dataframe.io.ArrowWriter |
| 18 | +import org.jetbrains.kotlinx.dataframe.io.arrowWriter |
| 19 | +import org.jetbrains.kotlinx.dataframe.io.readArrowFeather |
| 20 | +import org.jetbrains.kotlinx.dataframe.io.readArrowIPC |
| 21 | +import org.jetbrains.kotlinx.dataframe.io.writeArrowFeather |
| 22 | +import org.jetbrains.kotlinx.dataframe.io.saveArrowIPCToByteArray |
12 | 23 | import org.junit.Test
|
13 | 24 | import java.io.File
|
14 | 25 | import java.net.URL
|
@@ -112,19 +123,120 @@ internal class ArrowKtTest {
|
112 | 123 | citiesExampleFrame.writeArrowFeather(testFile)
|
113 | 124 | assertEstimation(DataFrame.readArrowFeather(testFile))
|
114 | 125 |
|
115 |
| - val testByteArray = citiesExampleFrame.arrowWriter().saveArrowIPCToByteArray() |
| 126 | + val testByteArray = citiesExampleFrame.saveArrowIPCToByteArray() |
116 | 127 | assertEstimation(DataFrame.readArrowIPC(testByteArray))
|
117 | 128 | }
|
118 | 129 |
|
119 | 130 | @Test
|
120 | 131 | fun testWritingBySchema() {
|
121 | 132 | val testFile = File.createTempFile("cities", "arrow")
|
122 |
| - citiesExampleFrame.arrowWriter(Schema.fromJSON(citiesExampleSchema)).writeArrowFeather(testFile) |
| 133 | + citiesExampleFrame.arrowWriter(Schema.fromJSON(citiesExampleSchema)).use { it.writeArrowFeather(testFile) } |
123 | 134 | val citiesDeserialized = DataFrame.readArrowFeather(testFile, NullabilityOptions.Checking)
|
124 | 135 | citiesDeserialized["population"].type() shouldBe typeOf<Long?>()
|
125 | 136 | citiesDeserialized["area"].type() shouldBe typeOf<Float>()
|
126 | 137 | citiesDeserialized["settled"].type() shouldBe typeOf<LocalDateTime>()
|
127 |
| - shouldThrow<IllegalArgumentException> { citiesDeserialized["page_in_wiki"] shouldBe null } |
| 138 | + shouldThrow<IllegalArgumentException> { citiesDeserialized["page_in_wiki"] } |
128 | 139 | citiesDeserialized["film_in_youtube"] shouldBe DataColumn.createValueColumn("film_in_youtube", arrayOfNulls<String>(citiesExampleFrame.rowsCount()).asList())
|
129 | 140 | }
|
| 141 | + |
| 142 | + @Test |
| 143 | + fun testWidening() { |
| 144 | + val warnings = ArrayList<String>() |
| 145 | + val testRestrictWidening = citiesExampleFrame.arrowWriter( |
| 146 | + Schema.fromJSON(citiesExampleSchema), |
| 147 | + ArrowWriter.Companion.Mode.STRICT |
| 148 | + ) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() } |
| 149 | + warnings.shouldContain("Column \"page_in_wiki\" is not described in target schema and was ignored") |
| 150 | + shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testRestrictWidening)["page_in_wiki"] } |
| 151 | + |
| 152 | + val testAllowWidening = citiesExampleFrame.arrowWriter( |
| 153 | + Schema.fromJSON(citiesExampleSchema), |
| 154 | + ArrowWriter.Companion.Mode( |
| 155 | + restrictWidening = false, |
| 156 | + restrictNarrowing = true, |
| 157 | + strictType = true, |
| 158 | + strictNullable = true |
| 159 | + ) |
| 160 | + ).use { it.saveArrowFeatherToByteArray() } |
| 161 | + DataFrame.readArrowFeather(testAllowWidening)["page_in_wiki"].values() shouldBe citiesExampleFrame["page_in_wiki"].values().map { it.toString() } |
| 162 | + } |
| 163 | + |
| 164 | + @Test |
| 165 | + fun testNarrowing() { |
| 166 | + val frameWithoutRequiredField = citiesExampleFrame.copy().remove("settled") |
| 167 | + |
| 168 | + frameWithoutRequiredField.arrowWriter( |
| 169 | + Schema.fromJSON(citiesExampleSchema), |
| 170 | + ArrowWriter.Companion.Mode.STRICT |
| 171 | + ).use { |
| 172 | + shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() } |
| 173 | + } |
| 174 | + |
| 175 | + val warnings = ArrayList<String>() |
| 176 | + val testAllowNarrowing = frameWithoutRequiredField.arrowWriter( |
| 177 | + Schema.fromJSON(citiesExampleSchema), |
| 178 | + ArrowWriter.Companion.Mode( |
| 179 | + restrictWidening = true, |
| 180 | + restrictNarrowing = false, |
| 181 | + strictType = true, |
| 182 | + strictNullable = true |
| 183 | + ) |
| 184 | + ) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() } |
| 185 | + warnings.shouldContain("Column \"settled\" is not presented") |
| 186 | + shouldThrow<IllegalArgumentException> { DataFrame.readArrowFeather(testAllowNarrowing)["settled"] } |
| 187 | + } |
| 188 | + |
| 189 | + @Test |
| 190 | + fun testStrictType() { |
| 191 | + val frameRenaming = citiesExampleFrame.copy().remove("settled") |
| 192 | + val frameWithIncompatibleField = frameRenaming.add(frameRenaming["is_capital"].map { value -> value ?: false }.rename("settled").convertToBoolean()) |
| 193 | + |
| 194 | + frameWithIncompatibleField.arrowWriter( |
| 195 | + Schema.fromJSON(citiesExampleSchema), |
| 196 | + ArrowWriter.Companion.Mode.STRICT |
| 197 | + ).use { |
| 198 | + shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() } |
| 199 | + } |
| 200 | + |
| 201 | + val warnings = ArrayList<String>() |
| 202 | + val testLoyalType = frameWithIncompatibleField.arrowWriter( |
| 203 | + Schema.fromJSON(citiesExampleSchema), |
| 204 | + ArrowWriter.Companion.Mode( |
| 205 | + restrictWidening = true, |
| 206 | + restrictNarrowing = true, |
| 207 | + strictType = false, |
| 208 | + strictNullable = true |
| 209 | + ) |
| 210 | + ) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() } |
| 211 | + warnings.shouldContain("Type converter from kotlin.Boolean to kotlinx.datetime.LocalDateTime? is not found") |
| 212 | + DataFrame.readArrowFeather(testLoyalType)["settled"].type() shouldBe typeOf<Boolean>() |
| 213 | + } |
| 214 | + |
| 215 | + @Test |
| 216 | + fun testStrictNullable() { |
| 217 | + val frameRenaming = citiesExampleFrame.copy().remove("settled") |
| 218 | + val frameWithNulls = frameRenaming.add(DataColumn.createValueColumn("settled", arrayOfNulls<LocalDate>(frameRenaming.rowsCount()).asList())) |
| 219 | + |
| 220 | + frameWithNulls.arrowWriter( |
| 221 | + Schema.fromJSON(citiesExampleSchema), |
| 222 | + ArrowWriter.Companion.Mode.STRICT |
| 223 | + ).use { |
| 224 | + shouldThrow<IllegalArgumentException> { it.saveArrowFeatherToByteArray() } |
| 225 | + } |
| 226 | + |
| 227 | + val warnings = ArrayList<String>() |
| 228 | + val testLoyalNullable = frameWithNulls.arrowWriter( |
| 229 | + Schema.fromJSON(citiesExampleSchema), |
| 230 | + ArrowWriter.Companion.Mode( |
| 231 | + restrictWidening = true, |
| 232 | + restrictNarrowing = true, |
| 233 | + strictType = true, |
| 234 | + strictNullable = false |
| 235 | + ) |
| 236 | + ) { warning -> warnings.add(warning) }.use { it.saveArrowFeatherToByteArray() } |
| 237 | + warnings.shouldContain("Column \"settled\" contains nulls but expected not nullable") |
| 238 | + DataFrame.readArrowFeather(testLoyalNullable)["settled"].type() shouldBe typeOf<LocalDateTime?>() |
| 239 | + DataFrame.readArrowFeather(testLoyalNullable)["settled"].values() shouldBe arrayOfNulls<LocalDate>(frameRenaming.rowsCount()).asList() |
| 240 | + } |
| 241 | + |
130 | 242 | }
|
0 commit comments