Skip to content

Commit 2d87082

Browse files
committed
Rework codegen of default read methods so SupportedFormat holds all necessary information
1 parent bd66615 commit 2d87082

File tree

17 files changed

+281
-303
lines changed

17 files changed

+281
-303
lines changed

build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ dependencies {
4444
implementation(libs.poi.ooxml)
4545

4646
implementation(libs.kotlin.datetimeJvm)
47+
implementation("com.squareup:kotlinpoet:1.11.0")
4748

4849
testImplementation(libs.junit)
4950
testImplementation(libs.kotestAssertions) {

dataframe-arrow/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ kotlinPublications {
2828
packageName.set(artifactId)
2929
}
3030
}
31+
32+
kotlin {
33+
explicitApi()
34+
}

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
3636
import org.jetbrains.kotlinx.dataframe.api.Infer
3737
import org.jetbrains.kotlinx.dataframe.api.concat
3838
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
39+
import org.jetbrains.kotlinx.dataframe.codeGen.AbstractDefaultReadMethod
40+
import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod
3941
import java.io.File
4042
import java.io.InputStream
4143
import java.math.BigDecimal
@@ -57,8 +59,16 @@ public class ArrowFeather : SupportedFormat {
5759
override fun acceptsExtension(ext: String): Boolean = ext == "feather"
5860

5961
override val testOrder: Int = 50000
62+
63+
override fun createDefaultReadMethod(pathRepresentation: String?): DefaultReadDfMethod {
64+
return DefaultReadArrowMethod(pathRepresentation)
65+
}
6066
}
6167

68+
private const val readArrowFeather = "readArrowFeather"
69+
70+
private class DefaultReadArrowMethod(path: String?) : AbstractDefaultReadMethod(path, MethodArguments.EMPTY, readArrowFeather)
71+
6272
internal object Allocator {
6373
val ROOT by lazy {
6474
RootAllocator(Long.MAX_VALUE)
@@ -78,7 +88,7 @@ public enum class ArrowFormat() {
7888
}
7989

8090
/**
81-
* Read [ArrowFormat.IPC] data from existing [channel]
91+
* Read [ArrowFormat.IPC] data from existing [channel]
8292
*/
8393
public fun readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
8494
ArrowStreamReader(channel, allocator).use { reader ->
@@ -95,7 +105,7 @@ public fun readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator =
95105
}
96106

97107
/**
98-
* Read [ArrowFormat.FEATHER] data from existing [channel]
108+
* Read [ArrowFormat.FEATHER] data from existing [channel]
99109
*/
100110
public fun readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame {
101111
ArrowFileReader(channel, allocator).use { reader ->

plugins/dataframe-gradle-plugin/build.gradle.kts

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@ repositories {
1313

1414
group = "org.jetbrains.kotlin"
1515

16+
dependencies {
17+
implementation(project(":"))
18+
implementation(project(":dataframe-arrow"))
19+
implementation(kotlin("gradle-plugin-api"))
20+
implementation(kotlin("gradle-plugin"))
21+
implementation("com.beust:klaxon:5.5")
22+
implementation(libs.ksp.gradle)
23+
implementation(libs.ksp.api)
24+
25+
testImplementation("junit:junit:4.12")
26+
testImplementation("io.kotest:kotest-assertions-core:4.6.0")
27+
testImplementation("com.android.tools.build:gradle-api:4.1.1")
28+
testImplementation("com.android.tools.build:gradle:4.1.1")
29+
testImplementation("io.ktor:ktor-server-netty:1.6.7")
30+
testImplementation(gradleApi())
31+
}
32+
1633
tasks.withType<ProcessResources> {
1734
filesMatching("**/plugin.properties") {
1835
filter {
@@ -69,22 +86,6 @@ tasks.withType<JavaCompile>().all {
6986
targetCompatibility = JavaVersion.VERSION_1_8.toString()
7087
}
7188

72-
dependencies {
73-
implementation(project(":"))
74-
implementation(kotlin("gradle-plugin-api"))
75-
implementation(kotlin("gradle-plugin"))
76-
implementation("com.beust:klaxon:5.5")
77-
implementation(libs.ksp.gradle)
78-
implementation(libs.ksp.api)
79-
80-
testImplementation("junit:junit:4.12")
81-
testImplementation("io.kotest:kotest-assertions-core:4.6.0")
82-
testImplementation("com.android.tools.build:gradle-api:4.1.1")
83-
testImplementation("com.android.tools.build:gradle:4.1.1")
84-
testImplementation("io.ktor:ktor-server-netty:1.6.7")
85-
testImplementation(gradleApi())
86-
}
87-
8889
sourceSets {
8990
create("integrationTest") {
9091
withConvention(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::class) {

plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/ApiChangesDetectionTest.kt

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,14 @@ package org.jetbrains.dataframe.gradle
33
import io.kotest.matchers.shouldBe
44
import org.gradle.testkit.runner.TaskOutcome
55
import org.jetbrains.kotlinx.dataframe.DataFrame
6-
import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadCsvMethod
7-
import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadJsonMethod
86
import org.junit.Test
97
import java.io.File
10-
import kotlin.reflect.KClass
11-
12-
annotation class RelatedGenerator(vararg val clazz: KClass<*>)
138

149
class ApiChangesDetectionTest : AbstractDataFramePluginIntegrationTest() {
15-
@RelatedGenerator(
16-
GenerateDataSchemaTask::class,
17-
DefaultReadCsvMethod::class,
18-
DefaultReadJsonMethod::class
19-
)
10+
11+
// Related generators: GenerateDataSchemaTask::class,
12+
// DefaultReadCsvMethod::class,
13+
// DefaultReadJsonMethod::class
2014
@Test
2115
fun `cast api`() {
2216
compiles {
@@ -34,10 +28,8 @@ class ApiChangesDetectionTest : AbstractDataFramePluginIntegrationTest() {
3428
}
3529
}
3630

37-
@RelatedGenerator(
38-
GenerateDataSchemaTask::class,
39-
DefaultReadJsonMethod::class
40-
)
31+
// Related generators: GenerateDataSchemaTask::class,
32+
// DefaultReadJsonMethod::class
4133
@Test
4234
fun `read json api`() {
4335
compiles {
@@ -51,11 +43,8 @@ class ApiChangesDetectionTest : AbstractDataFramePluginIntegrationTest() {
5143
""".trimIndent()
5244
}
5345
}
54-
55-
@RelatedGenerator(
56-
GenerateDataSchemaTask::class,
57-
DefaultReadCsvMethod::class,
58-
)
46+
// Related generators: GenerateDataSchemaTask::class,
47+
// DefaultReadCsvMethod::class,
5948
@Test
6049
fun `read csv api`() {
6150
compiles {

plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/GenerateDataSchemaTask.kt

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@ import org.gradle.api.tasks.Input
88
import org.gradle.api.tasks.OutputFile
99
import org.gradle.api.tasks.TaskAction
1010
import org.jetbrains.dataframe.impl.codeGen.CodeGenerator
11-
import org.jetbrains.kotlinx.dataframe.codeGen.CsvOptions
1211
import org.jetbrains.kotlinx.dataframe.codeGen.MarkerVisibility
1312
import org.jetbrains.kotlinx.dataframe.codeGen.NameNormalizer
1413
import org.jetbrains.kotlinx.dataframe.impl.codeGen.DfReadResult
1514
import org.jetbrains.kotlinx.dataframe.impl.codeGen.from
1615
import org.jetbrains.kotlinx.dataframe.impl.codeGen.toStandaloneSnippet
1716
import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlReader
17+
import org.jetbrains.kotlinx.dataframe.io.ArrowFeather
18+
import org.jetbrains.kotlinx.dataframe.io.CSV
19+
import org.jetbrains.kotlinx.dataframe.io.Excel
20+
import org.jetbrains.kotlinx.dataframe.io.JSON
21+
import org.jetbrains.kotlinx.dataframe.io.TSV
1822
import java.io.File
1923
import java.net.URL
2024
import java.nio.file.Paths
@@ -55,14 +59,32 @@ abstract class GenerateDataSchemaTask : DefaultTask() {
5559
@TaskAction
5660
fun generate() {
5761
val csvOptions = csvOptions.get()
58-
val (delimiter) = csvOptions
5962
val url = urlOf(data.get())
60-
val res = when (val readResult = CodeGenerator.urlReader(url, CsvOptions(delimiter))) {
63+
val formats = listOf(
64+
CSV(delimiter = csvOptions.delimiter),
65+
JSON(),
66+
Excel(),
67+
TSV(),
68+
ArrowFeather()
69+
)
70+
val res = when (val readResult = CodeGenerator.urlReader(url, formats)) {
6171
is DfReadResult.Success -> readResult
6272
is DfReadResult.Error -> throw Exception("Error while reading dataframe from data at $url", readResult.reason)
6373
}
74+
if (res.format is ArrowFeather) {
75+
val arrowDependency = project.configurations.asSequence()
76+
.mapNotNull { configuration ->
77+
configuration.allDependencies.find { it.group?.equals("org.jetbrains.kotlinx") ?: false && it.name == "dataframe-arrow" }
78+
}
79+
.firstOrNull()
80+
81+
if (arrowDependency == null) {
82+
project.logger.warn("Add dependency on \"org.jetbrains.kotlinx:dataframe-arrow\" to compile schema ${interfaceName.get()} generated from ${data.get()}")
83+
}
84+
}
6485
val codeGenerator = CodeGenerator.create(useFqNames = false)
6586
val delimiters = delimiters.get()
87+
val readDfMethod = res.getReadDfMethod(stringOf(data.get()))
6688
val codeGenResult = codeGenerator.generate(
6789
schema = res.schema,
6890
name = interfaceName.get(),
@@ -74,13 +96,13 @@ abstract class GenerateDataSchemaTask : DefaultTask() {
7496
DataSchemaVisibility.IMPLICIT_PUBLIC -> MarkerVisibility.IMPLICIT_PUBLIC
7597
DataSchemaVisibility.EXPLICIT_PUBLIC -> MarkerVisibility.EXPLICIT_PUBLIC
7698
},
77-
readDfMethod = res.getReadDfMethod(stringOf(data.get())),
99+
readDfMethod = readDfMethod,
78100
fieldNameNormalizer = NameNormalizer.from(delimiters)
79101
)
80102
val escapedPackageName = escapePackageName(packageName.get())
81103

82104
val dataSchema = dataSchema.get()
83-
dataSchema.writeText(codeGenResult.toStandaloneSnippet(escapedPackageName))
105+
dataSchema.writeText(codeGenResult.toStandaloneSnippet(escapedPackageName, readDfMethod.additionalImports))
84106
}
85107

86108
private fun stringOf(data: Any): String {

plugins/dataframe-gradle-plugin/src/test/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginTest.kt

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.jetbrains.dataframe.gradle
22

3+
import io.kotest.assertions.asClue
34
import io.kotest.inspectors.forOne
45
import io.kotest.matchers.shouldBe
56
import io.kotest.matchers.string.shouldContain
@@ -328,15 +329,17 @@ internal class SchemaGeneratorPluginTest {
328329
""".trimIndent()
329330
}
330331
result.task(":generateDataFrameTest")?.outcome shouldBe TaskOutcome.SUCCESS
331-
File(buildDir, "build/generated/dataframe/main/kotlin/org/test/Test.Generated.kt").readLines().let {
332-
it.forOne {
333-
it.shouldContain("val a")
334-
}
335-
it.forOne {
336-
it.shouldContain("val b")
337-
}
338-
it.forOne {
339-
it.shouldContain("val c")
332+
File(buildDir, "build/generated/dataframe/main/kotlin/org/test/Test.Generated.kt").asClue {
333+
it.readLines().let {
334+
it.forOne {
335+
it.shouldContain("val a")
336+
}
337+
it.forOne {
338+
it.shouldContain("val b")
339+
}
340+
it.forOne {
341+
it.shouldContain("val c")
342+
}
340343
}
341344
}
342345
}

plugins/symbol-processor/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ group = "org.jetbrains.kotlinx.dataframe"
1313

1414
dependencies {
1515
implementation(project(":"))
16+
implementation(project(":dataframe-arrow"))
1617
implementation(libs.ksp.api)
1718
testImplementation("org.jetbrains.kotlin:kotlin-test")
1819
testImplementation("com.github.tschuchortdev:kotlin-compile-testing:1.4.4")

plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,21 @@ import com.google.devtools.ksp.processing.KSPLogger
77
import com.google.devtools.ksp.processing.Resolver
88
import com.google.devtools.ksp.symbol.KSFile
99
import org.jetbrains.dataframe.impl.codeGen.CodeGenerator
10+
import org.jetbrains.kotlinx.dataframe.annotations.CsvOptions
1011
import org.jetbrains.kotlinx.dataframe.annotations.DataSchemaVisibility
1112
import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
1213
import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchemaByAbsolutePath
13-
import org.jetbrains.kotlinx.dataframe.codeGen.CsvOptions
1414
import org.jetbrains.kotlinx.dataframe.codeGen.MarkerVisibility
1515
import org.jetbrains.kotlinx.dataframe.codeGen.NameNormalizer
1616
import org.jetbrains.kotlinx.dataframe.impl.codeGen.DfReadResult
1717
import org.jetbrains.kotlinx.dataframe.impl.codeGen.from
1818
import org.jetbrains.kotlinx.dataframe.impl.codeGen.toStandaloneSnippet
1919
import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlReader
20+
import org.jetbrains.kotlinx.dataframe.io.ArrowFeather
21+
import org.jetbrains.kotlinx.dataframe.io.CSV
22+
import org.jetbrains.kotlinx.dataframe.io.Excel
23+
import org.jetbrains.kotlinx.dataframe.io.JSON
24+
import org.jetbrains.kotlinx.dataframe.io.TSV
2025
import java.io.File
2126
import java.net.MalformedURLException
2227
import java.net.URL
@@ -92,7 +97,7 @@ class DataSchemaGenerator(
9297
visibility.toMarkerVisibility(),
9398
normalizationDelimiters.toList(),
9499
withDefaultPath,
95-
CsvOptions(csvOptions.delimiter)
100+
csvOptions
96101
)
97102
}
98103

@@ -111,7 +116,7 @@ class DataSchemaGenerator(
111116
visibility.toMarkerVisibility(),
112117
normalizationDelimiters.toList(),
113118
withDefaultPath,
114-
CsvOptions(csvOptions.delimiter)
119+
csvOptions
115120
)
116121
}
117122

@@ -133,18 +138,28 @@ class DataSchemaGenerator(
133138
fun generateDataSchema(importStatement: ImportDataSchemaStatement) {
134139
val packageName = importStatement.origin.packageName.asString()
135140
val name = importStatement.name
136-
val csvOptions = CsvOptions(importStatement.csvOptions.delimiter)
137141
val schemaFile =
138142
codeGenerator.createNewFile(Dependencies(true, importStatement.origin), packageName, "$name.Generated")
139143

140-
val parsedDf = when (val readResult = CodeGenerator.urlReader(importStatement.dataSource.data, csvOptions)) {
144+
val formats = listOf(
145+
CSV(delimiter = importStatement.csvOptions.delimiter),
146+
JSON(),
147+
Excel(),
148+
TSV(),
149+
ArrowFeather()
150+
)
151+
152+
val parsedDf = when (val readResult = CodeGenerator.urlReader(importStatement.dataSource.data, formats)) {
141153
is DfReadResult.Success -> readResult
142154
is DfReadResult.Error -> {
143155
logger.error("Error while reading dataframe from data at ${importStatement.dataSource.pathRepresentation}: ${readResult.reason}")
144156
return
145157
}
146158
}
147159
val codeGenerator = CodeGenerator.create(useFqNames = false)
160+
161+
val readDfMethod =
162+
parsedDf.getReadDfMethod(importStatement.dataSource.pathRepresentation.takeIf { importStatement.withDefaultPath })
148163
val codeGenResult = codeGenerator.generate(
149164
parsedDf.schema,
150165
name,
@@ -153,10 +168,10 @@ class DataSchemaGenerator(
153168
isOpen = true,
154169
importStatement.visibility,
155170
emptyList(),
156-
parsedDf.getReadDfMethod(importStatement.dataSource.pathRepresentation.takeIf { importStatement.withDefaultPath }),
171+
readDfMethod,
157172
NameNormalizer.from(importStatement.normalizationDelimiters.toSet())
158173
)
159-
val code = codeGenResult.toStandaloneSnippet(packageName)
174+
val code = codeGenResult.toStandaloneSnippet(packageName, readDfMethod.additionalImports)
160175
schemaFile.bufferedWriter().use {
161176
it.write(code)
162177
}

0 commit comments

Comments
 (0)