
Commit 308e456

DataFrame reading refactored.
* Remove duplicate code.
* Add read functions to `SupportedFormats` enum class.
* Rename `headers` argument to `header`.
* Add `header` argument for JSON readers.
* Add `DataColumn<Iterable<*>>.splitInto(columnNames)`.
1 parent 9a309bf commit 308e456

11 files changed: +279 -177 lines changed

docs/StardustDocs/topics/read.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -25,15 +25,15 @@ DataFrame.readCSV(URL("https://raw.githubusercontent.com/Kotlin/dataframe/master
 
 All `readCSV` overloads support different options.
 For example, you can specify custom delimiter if it differs from `,`, charset
-and headers names if your CSV is missing them
+and column names if your CSV is missing them
 
 <!---FUN readCsvCustom-->
 
 ```kotlin
 val df = DataFrame.readCSV(
     file,
     delimiter = '|',
-    headers = listOf("A", "B", "C", "D"),
+    header = listOf("A", "B", "C", "D"),
     parserOptions = ParserOptions(nullStrings = setOf("not assigned"))
 )
 ```
````
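Callers that passed this argument by name need a one-word change after the rename. A minimal before/after sketch (the file and column names are illustrative):

```kotlin
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readCSV
import java.io.File

fun main() {
    // Before this commit the named argument was `headers`:
    // val df = DataFrame.readCSV(File("report.csv"), delimiter = '|', headers = listOf("A", "B", "C", "D"))

    // After this commit the argument is called `header`:
    val df = DataFrame.readCSV(
        File("report.csv"), // illustrative file
        delimiter = '|',
        header = listOf("A", "B", "C", "D")
    )
    println(df)
}
```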

src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/split.kt

Lines changed: 7 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dataframe.api
 
 import org.jetbrains.kotlinx.dataframe.AnyFrame
 import org.jetbrains.kotlinx.dataframe.ColumnsSelector
+import org.jetbrains.kotlinx.dataframe.DataColumn
 import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.DataRow
 import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor
@@ -291,3 +292,9 @@ public inline fun <T, C : Iterable<R>, reified R> Split<T, C>.inplace(): DataFra
 public fun <T, C, R> SplitWithTransform<T, C, R>.inplace(): DataFrame<T> = df.convert(columns).splitInplace(tartypeOf, transform)
 
 // endregion
+
+// region DataColumn
+
+public fun DataColumn<Iterable<*>>.splitInto(vararg names: String): AnyFrame = toDataFrame().split { this@splitInto }.into(*names)
+
+// endregion
```
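A minimal sketch of the new extension in use; the column name and values are illustrative, and the call assumes the `splitInto` overload added in this commit:

```kotlin
import org.jetbrains.kotlinx.dataframe.api.*

fun main() {
    // A column where every row holds an Iterable of values.
    val measurements = columnOf(listOf(1, 2), listOf(3, 4)) named "measurements"

    // Each element of the Iterable becomes its own column in the resulting frame.
    val df = measurements.splitInto("first", "second")
    println(df)
}
```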

src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 7 additions & 13 deletions
```diff
@@ -35,6 +35,7 @@ import org.jetbrains.kotlinx.dataframe.api.Infer
 import org.jetbrains.kotlinx.dataframe.api.concat
 import org.jetbrains.kotlinx.dataframe.api.toDataFrame
 import java.io.File
+import java.io.InputStream
 import java.math.BigDecimal
 import java.math.BigInteger
 import java.net.URL
@@ -174,23 +175,16 @@ public fun DataFrame.Companion.readArrow(file: File): AnyFrame {
     return Files.newByteChannel(file.toPath()).use { readArrow(it) }
 }
 
-public fun DataFrame.Companion.readArrow(url: URL): AnyFrame {
+public fun DataFrame.Companion.readArrow(stream: InputStream): AnyFrame = Channels.newChannel(stream).use { readArrow(it) }
+
+public fun DataFrame.Companion.readArrow(url: URL): AnyFrame =
     when {
-        setOf("http", "https", "ftp").any { url.protocol == it } -> {
-            url.openStream().use { stream ->
-                Channels.newChannel(stream).use { channel ->
-                    return readArrow(channel)
-                }
-            }
-        }
-        setOf("file").any { url.protocol == it } -> {
-            return readArrow(File(url.path))
-        }
+        url.isFile() -> readArrow(url.asFile())
+        url.isProtocolSupported() -> url.openStream().use { readArrow(it) }
         else -> {
-            throw IllegalArgumentException("invalid protocol for url $url")
+            throw IllegalArgumentException("Invalid protocol for url $url")
        }
    }
-}
 
 public fun DataFrame.Companion.readArrow(path: String): AnyFrame = if (path.isURL()) {
     readArrow(URL(path))
```
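A short sketch of the new `InputStream` overload; the file name is illustrative, and it assumes an Arrow file that the existing `readArrow` channel-based reader can parse:

```kotlin
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readArrow
import java.io.File

fun main() {
    // The overload wraps the stream in a channel internally, so any InputStream works.
    File("measurements.arrow").inputStream().use { stream ->
        val df = DataFrame.readArrow(stream)
        println(df.rowsCount())
    }
}
```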

src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/common.kt

Lines changed: 10 additions & 2 deletions
```diff
@@ -6,14 +6,14 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
 import org.jetbrains.kotlinx.dataframe.api.toDataFrame
 import org.jetbrains.kotlinx.dataframe.impl.columns.guessColumnType
+import java.io.File
 import java.io.IOException
 import java.io.InputStream
 import java.net.URL
 
 internal fun catchHttpResponse(url: URL, body: (InputStream) -> AnyFrame): AnyFrame {
     try {
-        val stream = url.openStream()
-        return body(stream)
+        return url.openStream().use(body)
     } catch (e: IOException) {
         if (e.message?.startsWith("Server returned HTTP response code") == true) {
            val (_, response, _) = url.toString().httpGet().responseString()
@@ -51,3 +51,11 @@ public fun <T> List<List<T>>.toDataFrame(containsColumns: Boolean = false): AnyF
 }
 
 internal fun String.isURL(): Boolean = listOf("http:", "https:", "ftp:").any { startsWith(it) }
+
+internal fun URL.isFile(): Boolean = protocol == "file"
+
+internal fun URL.asFileOrNull(): File? = if (isFile()) File(path) else null
+
+internal fun URL.asFile(): File = asFileOrNull()!!
+
+internal fun URL.isProtocolSupported(): Boolean = protocol in setOf("http", "https", "ftp")
```
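These helpers stay `internal`, so library users never call them directly; the standalone sketch below re-creates them only to illustrate how the readers dispatch on the URL protocol (the URLs are illustrative):

```kotlin
import java.io.File
import java.net.URL

// Standalone re-creation of the internal helpers added in common.kt.
fun URL.isFile(): Boolean = protocol == "file"
fun URL.asFileOrNull(): File? = if (isFile()) File(path) else null
fun URL.isProtocolSupported(): Boolean = protocol in setOf("http", "https", "ftp")

fun main() {
    val local = URL("file:///tmp/data.csv")
    val remote = URL("https://example.org/data.csv")

    println(local.isFile())               // true -> read straight from the file system
    println(remote.isProtocolSupported()) // true -> open a stream over the network
}
```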

src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/csv.kt

Lines changed: 74 additions & 48 deletions
```diff
@@ -53,12 +53,12 @@ public fun DataFrame.Companion.readDelimStr(
     colTypes: Map<String, ColType> = mapOf(),
     skipLines: Int = 0,
     readLines: Int? = null
-): DataFrame<*> = readDelim(StringReader(text), CSVType.DEFAULT.format.withHeader(), colTypes, skipLines, readLines)
+): DataFrame<*> = StringReader(text).use { readDelim(it, CSVType.DEFAULT.format.withHeader(), colTypes, skipLines, readLines) }
 
 public fun DataFrame.Companion.read(
     fileOrUrl: String,
     delimiter: Char,
-    headers: List<String> = listOf(),
+    header: List<String> = listOf(),
     colTypes: Map<String, ColType> = mapOf(),
     skipLines: Int = 0,
     readLines: Int? = null,
@@ -68,7 +68,7 @@ public fun DataFrame.Companion.read(
     catchHttpResponse(asURL(fileOrUrl)) {
         readDelim(
             it, delimiter,
-            headers, isCompressed(fileOrUrl),
+            header, isCompressed(fileOrUrl),
             getCSVType(fileOrUrl), colTypes,
             skipLines, readLines,
             duplicate, charset
@@ -78,7 +78,7 @@ public fun DataFrame.Companion.readCSV(
 public fun DataFrame.Companion.readCSV(
     fileOrUrl: String,
     delimiter: Char = ',',
-    headers: List<String> = listOf(),
+    header: List<String> = listOf(),
     colTypes: Map<String, ColType> = mapOf(),
     skipLines: Int = 0,
     readLines: Int? = null,
@@ -89,7 +89,7 @@ public fun DataFrame.Companion.readCSV(
     catchHttpResponse(asURL(fileOrUrl)) {
         readDelim(
             it, delimiter,
-            headers, isCompressed(fileOrUrl),
+            header, isCompressed(fileOrUrl),
             CSVType.DEFAULT, colTypes,
             skipLines, readLines,
             duplicate, charset,
@@ -100,7 +100,7 @@ public fun DataFrame.Companion.readCSV(
 public fun DataFrame.Companion.readCSV(
     file: File,
     delimiter: Char = ',',
-    headers: List<String> = listOf(),
+    header: List<String> = listOf(),
     colTypes: Map<String, ColType> = mapOf(),
     skipLines: Int = 0,
     readLines: Int? = null,
@@ -110,7 +110,7 @@ public fun DataFrame.Companion.readCSV(
 ): DataFrame<*> =
     readDelim(
         FileInputStream(file), delimiter,
-        headers, isCompressed(file),
+        header, isCompressed(file),
         CSVType.DEFAULT, colTypes,
         skipLines, readLines,
         duplicate, charset,
@@ -120,23 +120,43 @@ public fun DataFrame.Companion.readCSV(
 public fun DataFrame.Companion.readCSV(
     url: URL,
     delimiter: Char = ',',
-    headers: List<String> = listOf(),
+    header: List<String> = listOf(),
     colTypes: Map<String, ColType> = mapOf(),
     skipLines: Int = 0,
     readLines: Int? = null,
     duplicate: Boolean = true,
     charset: Charset = Charsets.UTF_8,
     parserOptions: ParserOptions? = null
 ): DataFrame<*> =
-    readDelim(
+    readCSV(
         url.openStream(), delimiter,
-        headers, isCompressed(url),
-        CSVType.DEFAULT, colTypes,
+        header, isCompressed(url),
+        colTypes,
         skipLines, readLines,
         duplicate, charset,
         parserOptions
     )
 
+public fun DataFrame.Companion.readCSV(
+    stream: InputStream,
+    delimiter: Char = ',',
+    header: List<String> = listOf(),
+    isCompressed: Boolean = false,
+    colTypes: Map<String, ColType> = mapOf(),
+    skipLines: Int = 0,
+    readLines: Int? = null,
+    duplicate: Boolean = true,
+    charset: Charset = Charsets.UTF_8,
+    parserOptions: ParserOptions? = null
+): DataFrame<*> = readDelim(
+    stream, delimiter,
+    header, isCompressed,
+    CSVType.DEFAULT, colTypes,
+    skipLines, readLines,
+    duplicate, charset,
+    parserOptions
+)
+
 private fun getCSVType(path: String): CSVType =
     when (path.substringAfterLast('.').toLowerCase()) {
         "csv" -> CSVType.DEFAULT
@@ -160,13 +180,13 @@ internal fun asURL(fileOrUrl: String): URL = (
     }
 ).toURL()
 
-private fun getFormat(type: CSVType, delimiter: Char, headers: List<String>, duplicate: Boolean): CSVFormat =
-    type.format.withDelimiter(delimiter).withHeader(*headers.toTypedArray()).withAllowDuplicateHeaderNames(duplicate)
+private fun getFormat(type: CSVType, delimiter: Char, header: List<String>, duplicate: Boolean): CSVFormat =
+    type.format.withDelimiter(delimiter).withHeader(*header.toTypedArray()).withAllowDuplicateHeaderNames(duplicate)
 
 public fun DataFrame.Companion.readDelim(
     inStream: InputStream,
     delimiter: Char = ',',
-    headers: List<String> = listOf(),
+    header: List<String> = listOf(),
     isCompressed: Boolean = false,
     csvType: CSVType,
     colTypes: Map<String, ColType> = mapOf(),
@@ -181,7 +201,14 @@ public fun DataFrame.Companion.readDelim(
     } else {
         BufferedReader(InputStreamReader(inStream, charset))
     }.run {
-        readDelim(this, getFormat(csvType, delimiter, headers, duplicate), colTypes, skipLines, readLines, parserOptions)
+        readDelim(
+            this,
+            getFormat(csvType, delimiter, header, duplicate),
+            colTypes,
+            skipLines,
+            readLines,
+            parserOptions
+        )
     }
 
 public enum class ColType {
@@ -222,47 +249,46 @@ public fun DataFrame.Companion.readDelim(
         repeat(skipLines) { reader.readLine() }
     }
 
-    format.parse(reader).use { csvParser ->
-        val records = if (readLines == null) {
-            csvParser.records
-        } else {
-            require(readLines >= 0) { "`readLines` must not be negative" }
-            val records = ArrayList<CSVRecord>(readLines)
-            val iter = csvParser.iterator()
-            var count = readLines ?: 0
-            while (iter.hasNext() && 0 < count--) {
-                records.add(iter.next())
-            }
-            records
+    val csvParser = format.parse(reader)
+    val records = if (readLines == null) {
+        csvParser.records
+    } else {
+        require(readLines >= 0) { "`readLines` must not be negative" }
+        val records = ArrayList<CSVRecord>(readLines)
+        val iter = csvParser.iterator()
+        var count = readLines ?: 0
+        while (iter.hasNext() && 0 < count--) {
+            records.add(iter.next())
        }
+        records
+    }
 
-        val columnNames = csvParser.headerNames.takeIf { it.isNotEmpty() }
-            ?: (1..records[0].count()).map { index -> "X$index" }
+    val columnNames = csvParser.headerNames.takeIf { it.isNotEmpty() }
+        ?: (1..records[0].count()).map { index -> "X$index" }
 
-        val generator = ColumnNameGenerator()
-        val uniqueNames = columnNames.map { generator.addUnique(it) }
+    val generator = ColumnNameGenerator()
+    val uniqueNames = columnNames.map { generator.addUnique(it) }
 
-        val cols = uniqueNames.mapIndexed { colIndex, colName ->
-            val defaultColType = colTypes[".default"]
-            val colType = colTypes[colName] ?: defaultColType
-            var hasNulls = false
-            val values = records.map {
-                it[colIndex].ifEmpty {
-                    hasNulls = true
-                    null
-                }
+    val cols = uniqueNames.mapIndexed { colIndex, colName ->
+        val defaultColType = colTypes[".default"]
+        val colType = colTypes[colName] ?: defaultColType
+        var hasNulls = false
+        val values = records.map {
+            it[colIndex].ifEmpty {
+                hasNulls = true
+                null
            }
-            val column = DataColumn.createValueColumn(colName, values, typeOf<String>().withNullability(hasNulls))
-            when (colType) {
-                null -> column.tryParse(parserOptions)
-                else -> {
-                    val parser = Parsers[colType.toType()]!!
-                    column.parse(parser, parserOptions)
-                }
+        }
+        val column = DataColumn.createValueColumn(colName, values, typeOf<String>().withNullability(hasNulls))
+        when (colType) {
+            null -> column.tryParse(parserOptions)
+            else -> {
+                val parser = Parsers[colType.toType()]!!
+                column.parse(parser, parserOptions)
            }
        }
-        return cols.toDataFrame()
    }
+    return cols.toDataFrame()
 }
 
 public fun AnyFrame.writeCSV(file: File, format: CSVFormat = CSVFormat.DEFAULT.withHeader()): Unit =
```
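A short sketch of the new stream-based `readCSV` overload, for example reading a headerless CSV from an arbitrary `InputStream`; the byte source and column names are illustrative:

```kotlin
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readCSV

fun main() {
    // Any InputStream works: a network response body, a classpath resource, or an in-memory buffer.
    val csvBytes = "1|Alice\n2|Bob\n".byteInputStream()

    val df = DataFrame.readCSV(
        csvBytes,
        delimiter = '|',
        header = listOf("id", "name") // supply names because the stream has no header row
    )
    println(df)
}
```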
