Skip to content

Commit 48247e1

Browse files
committed
[Compiler plugin] Support colsOf, colsAtAnyDepth, frameCols
1 parent 2b5fe8b commit 48247e1

File tree

10 files changed

+130
-8
lines changed

10 files changed

+130
-8
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/all.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
1010
import org.jetbrains.kotlinx.dataframe.DataRow
1111
import org.jetbrains.kotlinx.dataframe.Predicate
1212
import org.jetbrains.kotlinx.dataframe.RowFilter
13+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
1314
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.BehaviorArg
1415
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.ColumnDoesNotExistArg
1516
import org.jetbrains.kotlinx.dataframe.api.AllColumnsSelectionDsl.CommonAllSubsetDocs.ExampleArg
@@ -300,6 +301,7 @@ public interface AllColumnsSelectionDsl<out _UNUSED> {
300301
*
301302
* `df.`[select][DataFrame.select]` { `[all][ColumnsSelectionDsl.all]`() }`
302303
*/
304+
@Interpretable("All0")
303305
public fun ColumnsSelectionDsl<*>.all(): TransformableColumnSet<*> =
304306
asSingleColumn().allColumnsInternal()
305307

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/colsAtAnyDepth.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnFilter
44
import org.jetbrains.kotlinx.dataframe.DataColumn
55
import org.jetbrains.kotlinx.dataframe.DataFrame
66
import org.jetbrains.kotlinx.dataframe.DataRow
7+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
78
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar
89
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar.ColumnGroupName
910
import org.jetbrains.kotlinx.dataframe.api.ColsAtAnyDepthColumnsSelectionDsl.Grammar.ColumnSetName
@@ -138,6 +139,7 @@ public interface ColsAtAnyDepthColumnsSelectionDsl {
138139
*
139140
* `df.`[select][DataFrame.select]` { `[colsAtAnyDepth][ColumnsSelectionDsl.colsAtAnyDepth]` { !it.`[isColumnGroup][DataColumn.isColumnGroup]` } }`
140141
*/
142+
@Interpretable("ColsAtAnyDepth0")
141143
public fun ColumnsSelectionDsl<*>.colsAtAnyDepth(predicate: ColumnFilter<*> = { true }): ColumnSet<*> =
142144
asSingleColumn().colsAtAnyDepthInternal(predicate)
143145

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/colsOf.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.ColumnFilter
44
import org.jetbrains.kotlinx.dataframe.DataColumn
55
import org.jetbrains.kotlinx.dataframe.DataFrame
66
import org.jetbrains.kotlinx.dataframe.DataRow
7+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
78
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar
89
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar.ColumnGroupName
910
import org.jetbrains.kotlinx.dataframe.api.ColsOfColumnsSelectionDsl.Grammar.ColumnSetName
@@ -203,6 +204,7 @@ public fun <C> ColumnSet<*>.colsOf(
203204
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.FilterParam]
204205
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.Return]
205206
*/
207+
@Interpretable("ColsOf1")
206208
public inline fun <reified C> ColumnSet<*>.colsOf(
207209
noinline filter: ColumnFilter<C> = { true },
208210
): TransformableColumnSet<C> = colsOf(typeOf<C>(), filter)
@@ -228,6 +230,7 @@ public fun <C> ColumnsSelectionDsl<*>.colsOf(
228230
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.FilterParam]
229231
* @include [ColsOfColumnsSelectionDsl.CommonColsOfDocs.Return]
230232
*/
233+
@Interpretable("ColsOf0")
231234
public inline fun <reified C> ColumnsSelectionDsl<*>.colsOf(
232235
noinline filter: ColumnFilter<C> = { true },
233236
): TransformableColumnSet<C> = asSingleColumn().colsOf(typeOf<C>(), filter)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/frameCols.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.DataRow
66
import org.jetbrains.kotlinx.dataframe.Predicate
7+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
78
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.ColumnGroupName
89
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.ColumnSetName
910
import org.jetbrains.kotlinx.dataframe.api.FrameColsColumnsSelectionDsl.Grammar.PlainDslName
@@ -111,6 +112,7 @@ public interface FrameColsColumnsSelectionDsl {
111112
*
112113
* `df.`[select][DataFrame.select]` { `[frameCols][ColumnsSelectionDsl.frameCols]` { it.`[name][ColumnReference.name]`.`[startsWith][String.startsWith]`("my") } }`
113114
*/
115+
@Interpretable("FrameCols0")
114116
public fun ColumnSet<*>.frameCols(filter: Predicate<FrameColumn<*>> = { true }): TransformableColumnSet<DataFrame<*>> =
115117
frameColumnsInternal(filter)
116118

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ExpressionAnalysisAdditionalChecker.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ private class Checker(val cache: FirCache<String, PluginDataFrameSchema, KotlinT
9090
val targetProjection = expression.typeArguments.getOrNull(0) as? FirTypeProjectionWithVariance ?: return
9191
val targetType = targetProjection.typeRef.coneType as? ConeClassLikeType ?: return
9292
val target = pluginDataFrameSchema(targetType)
93-
val sourceColumns = source.flatten()
94-
val targetColumns = target.flatten()
93+
val sourceColumns = source.flatten(includeFrames = true)
94+
val targetColumns = target.flatten(includeFrames = true)
9595
val sourceMap = sourceColumns.associate { it.path.path to it.column }
9696
val missingColumns = mutableListOf<String>()
9797
var valid = true

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/flattenPluginSchema.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,24 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
88
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
99
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
1010

11-
fun PluginDataFrameSchema.flatten(): List<ColumnWithPathApproximation> {
11+
fun PluginDataFrameSchema.flatten(includeFrames: Boolean): List<ColumnWithPathApproximation> {
1212
if (columns().isEmpty()) return emptyList()
1313
val columns = mutableListOf<ColumnWithPathApproximation>()
14-
flattenImpl(columns(), emptyList(), columns)
14+
flattenImpl(columns(), emptyList(), columns, includeFrames)
1515
return columns
1616
}
1717

18-
fun flattenImpl(columns: List<SimpleCol>, path: List<String>, flatList: MutableList<ColumnWithPathApproximation>) {
18+
fun flattenImpl(columns: List<SimpleCol>, path: List<String>, flatList: MutableList<ColumnWithPathApproximation>, includeFrames: Boolean) {
1919
columns.forEach { column ->
2020
val fullPath = path + listOf(column.name)
2121
when (column) {
2222
is SimpleColumnGroup -> {
2323
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))
24-
flattenImpl(column.columns(), fullPath, flatList)
24+
flattenImpl(column.columns(), fullPath, flatList, includeFrames)
2525
}
2626
is SimpleFrameColumn -> {
2727
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))
28-
flattenImpl(column.columns(), fullPath, flatList)
28+
flattenImpl(column.columns(), fullPath, flatList, includeFrames)
2929
}
3030
is SimpleDataColumn -> {
3131
flatList.add(ColumnWithPathApproximation(ColumnPathApproximation(fullPath), column))

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/select.kt

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
22

3+
import org.jetbrains.kotlin.fir.types.ConeKotlinType
4+
import org.jetbrains.kotlin.fir.types.isSubtypeOf
35
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
46
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
57
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
68
import org.jetbrains.kotlinx.dataframe.api.Infer
79
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
810
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
11+
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
912
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
1013
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
1114
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
@@ -27,7 +30,12 @@ internal class Expr0 : AbstractInterpreter<ColumnsResolver>() {
2730
val Arguments.expression: TypeApproximation by type()
2831

2932
override fun Arguments.interpret(): ColumnsResolver {
30-
return SingleColumnApproximation(ColumnWithPathApproximation(ColumnPathApproximation(listOf(name)), SimpleDataColumn(name, expression)))
33+
return SingleColumnApproximation(
34+
ColumnWithPathApproximation(
35+
ColumnPathApproximation(listOf(name)),
36+
SimpleDataColumn(name, expression)
37+
)
38+
)
3139
}
3240
}
3341

@@ -43,3 +51,75 @@ internal class And0 : AbstractInterpreter<ColumnsResolver>() {
4351
}
4452
}
4553
}
54+
55+
internal class All0 : AbstractInterpreter<ColumnsResolver>() {
56+
override fun Arguments.interpret(): ColumnsResolver {
57+
return object : ColumnsResolver {
58+
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
59+
return df.columns().map {
60+
val path = ColumnPathApproximation(listOf(it.name))
61+
ColumnWithPathApproximation(path, it)
62+
}
63+
}
64+
}
65+
}
66+
}
67+
68+
internal class ColsOf0 : AbstractInterpreter<ColumnsResolver>() {
69+
val Arguments.typeArg0: TypeApproximation by arg()
70+
71+
override fun Arguments.interpret(): ColumnsResolver {
72+
return object : ColumnsResolver {
73+
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
74+
val cols = df.columns().map {
75+
val path = ColumnPathApproximation(listOf(it.name))
76+
ColumnWithPathApproximation(path, it)
77+
}
78+
return colsOf(cols, typeArg0.type)
79+
}
80+
}
81+
}
82+
83+
}
84+
85+
private fun Arguments.colsOf(cols: List<ColumnWithPathApproximation>, type: ConeKotlinType) =
86+
cols
87+
.filter {
88+
val column = it.column
89+
column is SimpleDataColumn && column.type.type.isSubtypeOf(type, session)
90+
}
91+
92+
internal class ColsAtAnyDepth0 : AbstractInterpreter<ColumnsResolver>() {
93+
override fun Arguments.interpret(): ColumnsResolver {
94+
return object : ColumnsResolver {
95+
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
96+
return df.flatten(includeFrames = false)
97+
}
98+
}
99+
}
100+
}
101+
102+
internal class ColsOf1 : AbstractInterpreter<ColumnsResolver>() {
103+
val Arguments.receiver: ColumnsResolver by arg()
104+
val Arguments.typeArg0: TypeApproximation by arg()
105+
106+
override fun Arguments.interpret(): ColumnsResolver {
107+
return object : ColumnsResolver {
108+
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
109+
return colsOf(receiver.resolve(df), typeArg0.type)
110+
}
111+
}
112+
}
113+
}
114+
115+
internal class FrameCols0 : AbstractInterpreter<ColumnsResolver>() {
116+
val Arguments.receiver: ColumnsResolver by arg()
117+
118+
override fun Arguments.interpret(): ColumnsResolver {
119+
return object : ColumnsResolver {
120+
override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
121+
return receiver.resolve(df).filter { it.column is SimpleFrameColumn }
122+
}
123+
}
124+
}
125+
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ import org.jetbrains.kotlin.fir.types.classId
6767
import org.jetbrains.kotlin.fir.types.coneType
6868
import org.jetbrains.kotlin.name.ClassId
6969
import org.jetbrains.kotlin.name.Name
70+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All0
71+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth0
72+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
73+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
74+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
7075
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
7176

7277
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
@@ -161,6 +166,11 @@ internal inline fun <reified T> String.load(): T {
161166
"ReadDelimStr" -> ReadDelimStr()
162167
"GroupByToDataFrame" -> GroupByToDataFrame()
163168
"ToDataFrameFrom0" -> ToDataFrameFrom()
169+
"All0" -> All0()
170+
"ColsOf0" -> ColsOf0()
171+
"ColsOf1" -> ColsOf1()
172+
"ColsAtAnyDepth0" -> ColsAtAnyDepth0()
173+
"FrameCols0" -> FrameCols0()
164174
else -> error("$this")
165175
} as T
166176
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
data class Nested(val d: List<Double>)
7+
8+
data class Record(val a: String, val b: Int, val nested: List<Nested>)
9+
10+
fun box(): String {
11+
val df = listOf(Record("112", 42, listOf(Nested(listOf(3.0))))).toDataFrame(maxDepth = 1)
12+
13+
df.group { nested }.into("group").convert { colsAtAnyDepth().frameCols() }.with { 1 }.compareSchemas()
14+
df.group { b }.into("group").convert { colsAtAnyDepth().colsOf<Int>() }.with { "" }.compareSchemas()
15+
df.convert { all().frameCols() }.with { 1 }.compareSchemas()
16+
return "OK"
17+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ public void testSelectThis() {
310310
runTest("testData/box/selectThis.kt");
311311
}
312312

313+
@Test
314+
@TestMetadata("selectionDsl.kt")
315+
public void testSelectionDsl() {
316+
runTest("testData/box/selectionDsl.kt");
317+
}
318+
313319
@Test
314320
@TestMetadata("toDataFrame.kt")
315321
public void testToDataFrame() {

0 commit comments

Comments
 (0)