Skip to content

Commit e8ce4b1

Browse files
authored
Merge pull request #133 from Kopilov/converting
More Converting operations
2 parents 62d327b + d30824a commit e8ce4b1

File tree

10 files changed

+275
-73
lines changed

10 files changed

+275
-73
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/convert.kt

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
1616
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1717
import org.jetbrains.kotlinx.dataframe.dataTypes.IFRAME
1818
import org.jetbrains.kotlinx.dataframe.dataTypes.IMG
19+
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
1920
import org.jetbrains.kotlinx.dataframe.impl.api.Parsers
2021
import org.jetbrains.kotlinx.dataframe.impl.api.convertRowColumnImpl
2122
import org.jetbrains.kotlinx.dataframe.impl.api.convertToTypeImpl
@@ -30,7 +31,7 @@ import org.jetbrains.kotlinx.dataframe.io.toDataFrame
3031
import java.math.BigDecimal
3132
import java.net.URL
3233
import java.time.LocalTime
33-
import java.util.*
34+
import java.util.Locale
3435
import kotlin.reflect.KProperty
3536
import kotlin.reflect.KType
3637
import kotlin.reflect.typeOf
@@ -99,7 +100,11 @@ public fun <T, C> Convert<T, C>.to(columnConverter: DataFrame<T>.(DataColumn<C>)
99100
df.replace(columns).with { columnConverter(df, it) }
100101

101102
public inline fun <reified C> AnyCol.convertTo(): DataColumn<C> = convertTo(typeOf<C>()) as DataColumn<C>
102-
public fun AnyCol.convertTo(newType: KType): AnyCol = convertToTypeImpl(newType)
103+
public fun AnyCol.convertTo(newType: KType): AnyCol {
104+
if (this.type() == typeOf<String>() && newType == typeOf<Double>()) return (this as DataColumn<String>).convertToDouble()
105+
if (this.type() == typeOf<String?>() && newType == typeOf<Double?>()) return (this as DataColumn<String?>).convertToDouble()
106+
return convertToTypeImpl(newType)
107+
}
103108

104109
@JvmName("convertToLocalDateTimeFromT")
105110
public fun <T : Any> DataColumn<T>.convertToLocalDateTime(): DataColumn<LocalDateTime> = convertTo()
@@ -125,6 +130,37 @@ public fun <T : Any> DataColumn<T?>.convertToString(): DataColumn<String?> = con
125130
public fun <T : Any> DataColumn<T>.convertToDouble(): DataColumn<Double> = convertTo()
126131
public fun <T : Any> DataColumn<T?>.convertToDouble(): DataColumn<Double?> = convertTo()
127132

133+
/**
134+
* Parse String column to Double considering locale (number format).
135+
* If the [locale] parameter is defined, its number format is used for parsing.
136+
* If the [locale] parameter is null, the current system locale is used. If the column cannot be parsed with it, the POSIX format is used as a fallback.
137+
*/
138+
@JvmName("convertToDoubleFromString")
139+
public fun DataColumn<String>.convertToDouble(locale: Locale? = null): DataColumn<Double> {
140+
return this.castToNullable().convertToDouble(locale).castToNotNullable()
141+
}
142+
143+
/**
144+
* Parse String column to Double considering locale (number format).
145+
* If the [locale] parameter is defined, its number format is used for parsing.
146+
* If the [locale] parameter is null, the current system locale is used. If the column cannot be parsed with it, the POSIX format is used as a fallback.
147+
*/
148+
@JvmName("convertToDoubleFromStringNullable")
149+
public fun DataColumn<String?>.convertToDouble(locale: Locale? = null): DataColumn<Double?> {
150+
if (locale != null) {
151+
val explicitParser = Parsers.getDoubleParser(locale)
152+
return map { it?.let { explicitParser(it.trim()) ?: throw TypeConversionException(it, typeOf<String>(), typeOf<Double>()) } }
153+
} else {
154+
return try {
155+
val defaultParser = Parsers.getDoubleParser()
156+
map { it?.let { defaultParser(it.trim()) ?: throw TypeConversionException(it, typeOf<String>(), typeOf<Double>()) } }
157+
} catch (e: TypeConversionException) {
158+
val posixParser = Parsers.getDoubleParser(Locale.forLanguageTag("C.UTF-8"))
159+
map { it?.let { posixParser(it.trim()) ?: throw TypeConversionException(it, typeOf<String>(), typeOf<Double>()) } }
160+
}
161+
}
162+
}
163+
128164
@JvmName("convertToFloatFromT")
129165
public fun <T : Any> DataColumn<T>.convertToFloat(): DataColumn<Float> = convertTo()
130166
public fun <T : Any> DataColumn<T?>.convertToFloat(): DataColumn<Float?> = convertTo()

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/convert.kt

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.jetbrains.kotlinx.dataframe.type
3636
import java.math.BigDecimal
3737
import java.net.URL
3838
import java.time.LocalTime
39+
import java.util.Locale
3940
import kotlin.math.roundToInt
4041
import kotlin.math.roundToLong
4142
import kotlin.reflect.KType
@@ -81,15 +82,15 @@ internal fun AnyCol.convertToTypeImpl(to: KType): AnyCol {
8182
return when {
8283
from == to -> this
8384
from.isSubtypeOf(to) -> (this as DataColumnInternal<*>).changeType(to.withNullability(hasNulls()))
84-
else -> when (val converter = getConverter(from, to)) {
85+
else -> when (val converter = getConverter(from, to, ParserOptions(locale = Locale.getDefault()))) {
8586
null -> when (from.classifier) {
8687
Any::class, Number::class, java.io.Serializable::class -> {
8788
// find converter for every value
8889
val values = values.map {
8990
it?.let {
9091
val clazz = it.javaClass.kotlin
9192
val type = clazz.createStarProjectedType(false)
92-
val converter = getConverter(type, to) ?: throw TypeConverterNotFoundException(from, to)
93+
val converter = getConverter(type, to, ParserOptions(locale = Locale.getDefault())) ?: throw TypeConverterNotFoundException(from, to)
9394
converter(it)
9495
}.checkNulls()
9596
}
@@ -107,9 +108,9 @@ internal fun AnyCol.convertToTypeImpl(to: KType): AnyCol {
107108
}
108109
}
109110

110-
internal val convertersCache = mutableMapOf<Pair<KType, KType>, TypeConverter?>()
111+
internal val convertersCache = mutableMapOf<Triple<KType, KType, ParserOptions?>, TypeConverter?>()
111112

112-
internal fun getConverter(from: KType, to: KType): TypeConverter? = convertersCache.getOrPut(from to to) { createConverter(from, to) }
113+
internal fun getConverter(from: KType, to: KType, options: ParserOptions? = null): TypeConverter? = convertersCache.getOrPut(Triple(from, to, options)) { createConverter(from, to, options) }
113114

114115
internal typealias TypeConverter = (Any) -> Any?
115116

@@ -205,6 +206,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
205206
Byte::class -> convert<Number> { it.toByte() }
206207
Short::class -> convert<Number> { it.toShort() }
207208
Long::class -> convert<Number> { it.toLong() }
209+
Boolean::class -> convert<Number> { it.toDouble() != 0.0 }
208210
else -> null
209211
}
210212
Int::class -> when (toClass) {
@@ -214,6 +216,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
214216
Short::class -> convert<Int> { it.toShort() }
215217
Long::class -> convert<Int> { it.toLong() }
216218
BigDecimal::class -> convert<Int> { it.toBigDecimal() }
219+
Boolean::class -> convert<Int> { it != 0 }
217220
LocalDateTime::class -> convert<Int> { it.toLong().toLocalDateTime(defaultTimeZone) }
218221
LocalDate::class -> convert<Int> { it.toLong().toLocalDate(defaultTimeZone) }
219222
java.time.LocalDateTime::class -> convert<Long> { it.toLocalDateTime(defaultTimeZone).toJavaLocalDateTime() }
@@ -227,6 +230,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
227230
Long::class -> convert<Double> { it.roundToLong() }
228231
Short::class -> convert<Double> { it.roundToInt().toShort() }
229232
BigDecimal::class -> convert<Double> { it.toBigDecimal() }
233+
Boolean::class -> convert<Double> { it != 0.0 }
230234
else -> null
231235
}
232236
Long::class -> when (toClass) {
@@ -236,6 +240,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
236240
Short::class -> convert<Long> { it.toShort() }
237241
Int::class -> convert<Long> { it.toInt() }
238242
BigDecimal::class -> convert<Long> { it.toBigDecimal() }
243+
Boolean::class -> convert<Long> { it != 0L }
239244
LocalDateTime::class -> convert<Long> { it.toLocalDateTime(defaultTimeZone) }
240245
LocalDate::class -> convert<Long> { it.toLocalDate(defaultTimeZone) }
241246
Instant::class -> convert<Long> { Instant.fromEpochMilliseconds(it) }
@@ -270,13 +275,15 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
270275
Int::class -> convert<Float> { it.roundToInt() }
271276
Short::class -> convert<Float> { it.roundToInt().toShort() }
272277
BigDecimal::class -> convert<Float> { it.toBigDecimal() }
278+
Boolean::class -> convert<Float> { it != 0.0F }
273279
else -> null
274280
}
275281
BigDecimal::class -> when (toClass) {
276282
Double::class -> convert<BigDecimal> { it.toDouble() }
277283
Int::class -> convert<BigDecimal> { it.toInt() }
278284
Float::class -> convert<BigDecimal> { it.toFloat() }
279285
Long::class -> convert<BigDecimal> { it.toLong() }
286+
Boolean::class -> convert<BigDecimal> { it != BigDecimal.ZERO }
280287
else -> null
281288
}
282289
LocalDateTime::class -> when (toClass) {
@@ -285,6 +292,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
285292
Long::class -> convert<LocalDateTime> { it.toInstant(defaultTimeZone).toEpochMilliseconds() }
286293
java.time.LocalDateTime::class -> convert<LocalDateTime> { it.toJavaLocalDateTime() }
287294
java.time.LocalDate::class -> convert<LocalDateTime> { it.date.toJavaLocalDate() }
295+
java.time.LocalTime::class -> convert<LocalDateTime> { it.toJavaLocalDateTime().toLocalTime() }
288296
else -> null
289297
}
290298
java.time.LocalDateTime::class -> when (toClass) {
@@ -293,6 +301,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
293301
Instant::class -> convert<java.time.LocalDateTime> { it.toKotlinLocalDateTime().toInstant(defaultTimeZone) }
294302
Long::class -> convert<java.time.LocalDateTime> { it.toKotlinLocalDateTime().toInstant(defaultTimeZone).toEpochMilliseconds() }
295303
java.time.LocalDate::class -> convert<java.time.LocalDateTime> { it.toLocalDate() }
304+
java.time.LocalTime::class -> convert<java.time.LocalDateTime> { it.toLocalTime() }
296305
else -> null
297306
}
298307
LocalDate::class -> when (toClass) {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/parse.kt

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ internal object Parsers : GlobalParserOptions {
194194
inline fun <reified T : Any> stringParserWithOptions(noinline body: (ParserOptions?) -> ((String) -> T?)) =
195195
StringParserWithFormat(typeOf<T>(), body)
196196

197+
private val parserToDoubleWithOptions = stringParserWithOptions { options ->
198+
val numberFormat = NumberFormat.getInstance(options?.locale ?: Locale.getDefault())
199+
val parser = { it: String -> it.parseDouble(numberFormat) }
200+
parser
201+
}
202+
197203
private val parsersOrder = listOf(
198204
stringParser { it.toIntOrNull() },
199205
stringParser { it.toLongOrNull() },
@@ -226,12 +232,12 @@ internal object Parsers : GlobalParserOptions {
226232

227233
stringParser { it.toUrlOrNull() },
228234

229-
stringParserWithOptions { options ->
235+
// Double, with explicit number format or taken from current locale
236+
parserToDoubleWithOptions,
237+
238+
// Double, with POSIX format
239+
stringParser { it.parseDouble(NumberFormat.getInstance(Locale.forLanguageTag("C.UTF-8"))) },
230240

231-
val numberFormat = NumberFormat.getInstance(options?.locale ?: Locale.getDefault())
232-
val parser = { it: String -> it.parseDouble(numberFormat) }
233-
parser
234-
},
235241
stringParser { it.toBooleanOrNull() },
236242
stringParser { it.toBigDecimalOrNull() },
237243

@@ -266,6 +272,13 @@ internal object Parsers : GlobalParserOptions {
266272
) else null
267273
return parser.applyOptions(options)
268274
}
275+
276+
internal fun getDoubleParser(locale: Locale? = null): (String) -> Double? {
277+
val options = if (locale != null) ParserOptions(
278+
locale = locale
279+
) else null
280+
return parserToDoubleWithOptions.applyOptions(options)
281+
}
269282
}
270283

271284
internal fun DataColumn<String?>.tryParseImpl(options: ParserOptions?): DataColumn<*> {

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/CsvTests.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,27 @@ class CsvTests {
104104
assertColumnType("quality", Int::class)
105105
}
106106

107+
@Test
108+
fun `read standard CSV with floats when user has alternative locale`() {
109+
val currentLocale = Locale.getDefault()
110+
try {
111+
Locale.setDefault(Locale.forLanguageTag("ru-RU"))
112+
val df = DataFrame.readCSV(wineCsv, delimiter = ';')
113+
val schema = df.schema()
114+
fun assertColumnType(columnName: String, kClass: KClass<*>) {
115+
val col = schema.columns[columnName]
116+
col.shouldNotBeNull()
117+
col.type.classifier shouldBe kClass
118+
}
119+
120+
assertColumnType("citric acid", Double::class)
121+
assertColumnType("alcohol", Double::class)
122+
assertColumnType("quality", Int::class)
123+
} finally {
124+
Locale.setDefault(currentLocale)
125+
}
126+
}
127+
107128
@Test
108129
fun `read with custom header`() {
109130
val header = ('A'..'K').map { it.toString() }

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ParserTests.kt

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
package org.jetbrains.kotlinx.dataframe.io
22

3+
import io.kotest.assertions.throwables.shouldThrow
34
import io.kotest.matchers.shouldBe
45
import kotlinx.datetime.LocalDateTime
6+
import org.jetbrains.kotlinx.dataframe.DataColumn
57
import org.jetbrains.kotlinx.dataframe.DataFrame
6-
import org.jetbrains.kotlinx.dataframe.api.*
8+
import org.jetbrains.kotlinx.dataframe.api.cast
79
import org.jetbrains.kotlinx.dataframe.api.columnOf
10+
import org.jetbrains.kotlinx.dataframe.api.convertTo
11+
import org.jetbrains.kotlinx.dataframe.api.convertToDouble
12+
import org.jetbrains.kotlinx.dataframe.api.parse
13+
import org.jetbrains.kotlinx.dataframe.api.parser
14+
import org.jetbrains.kotlinx.dataframe.api.tryParse
815
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
916
import org.junit.Test
1017
import java.math.BigDecimal
18+
import java.util.Locale
1119
import kotlin.reflect.typeOf
1220

1321
class ParserTests {
@@ -58,4 +66,88 @@ class ParserTests {
5866
converted[0] shouldBe 1.0f
5967
converted[1] shouldBe 0.321f
6068
}
69+
70+
@Test
71+
fun `convert to Boolean`() {
72+
val col by columnOf(BigDecimal(1.0), BigDecimal(0.0), 0, 1, 10L, 0.0, 0.1)
73+
col.convertTo<Boolean>().shouldBe(
74+
DataColumn.createValueColumn("col", listOf(true, false, false, true, true, false, true), typeOf<Boolean>())
75+
)
76+
}
77+
78+
@Test
79+
fun `converting String to Double in different locales`() {
80+
val currentLocale = Locale.getDefault()
81+
try {
82+
// Test 36 behaviour combinations:
83+
84+
// 3 source columns
85+
val columnDot = columnOf("12.345", "67.890")
86+
val columnComma = columnOf("12,345", "67,890")
87+
val columnMixed = columnOf("12.345", "67,890")
88+
// *
89+
// (3 locales as converting parameter + original converting)
90+
val parsingLocaleNotDefined: Locale? = null
91+
val parsingLocaleUsesDot: Locale = Locale.forLanguageTag("en-US")
92+
val parsingLocaleUsesComma: Locale = Locale.forLanguageTag("ru-RU")
93+
// *
94+
// 3 system locales
95+
96+
Locale.setDefault(Locale.forLanguageTag("C.UTF-8"))
97+
98+
columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
99+
columnComma.convertTo<Double>().shouldBe(columnOf(12345.0, 67890.0))
100+
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))
101+
102+
columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
103+
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12345.0, 67890.0))
104+
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))
105+
106+
columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
107+
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
108+
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))
109+
110+
shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
111+
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
112+
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }
113+
114+
Locale.setDefault(Locale.forLanguageTag("en-US"))
115+
116+
columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
117+
columnComma.convertTo<Double>().shouldBe(columnOf(12345.0, 67890.0))
118+
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))
119+
120+
columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
121+
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12345.0, 67890.0))
122+
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))
123+
124+
columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
125+
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
126+
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))
127+
128+
shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
129+
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
130+
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }
131+
132+
Locale.setDefault(Locale.forLanguageTag("ru-RU"))
133+
134+
columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
135+
columnComma.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
136+
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))
137+
138+
columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
139+
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
140+
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))
141+
142+
columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
143+
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
144+
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))
145+
146+
shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
147+
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
148+
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }
149+
} finally {
150+
Locale.setDefault(currentLocale)
151+
}
152+
}
61153
}

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/SampleNotebooksTests.kt

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.jetbrains.jupyter.parser.notebook.Output
66
import org.junit.Ignore
77
import org.junit.Test
88
import java.io.File
9+
import java.util.Locale
910

1011
class SampleNotebooksTests : DataFrameJupyterTest() {
1112
@Test
@@ -39,13 +40,23 @@ class SampleNotebooksTests : DataFrameJupyterTest() {
3940
)
4041

4142
@Test
42-
fun netflix() = exampleTest(
43-
"netflix",
44-
replacer = CodeReplacer.byMap(
45-
testFile("netflix", "country_codes.csv"),
46-
testFile("netflix", "netflix_titles.csv"),
47-
)
48-
)
43+
fun netflix() {
44+
val currentLocale = Locale.getDefault()
45+
try {
46+
// Set an explicit locale because the test data contains locale-dependent values (dates to be parsed)
47+
Locale.setDefault(Locale.forLanguageTag("en-US"))
48+
49+
exampleTest(
50+
"netflix",
51+
replacer = CodeReplacer.byMap(
52+
testFile("netflix", "country_codes.csv"),
53+
testFile("netflix", "netflix_titles.csv"),
54+
)
55+
)
56+
} finally {
57+
Locale.setDefault(currentLocale)
58+
}
59+
}
4960

5061
@Test
5162
@Ignore

0 commit comments

Comments
 (0)