Skip to content

Commit 9505424

Browse files
committed
[Compiler plugin] Rework toDataFrame implementation
Interpreters need an ability to pass arguments down to DSL, so introduce new "dsl" factory function
1 parent ed4e6f0 commit 9505424

File tree

6 files changed

+63
-62
lines changed

6 files changed

+63
-62
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,19 @@
55

66
package org.jetbrains.kotlinx.dataframe.plugin
77

8-
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
9-
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
10-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TraverseConfiguration
11-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
12-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.toDataFrame
13-
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
14-
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
158
import org.jetbrains.kotlin.fir.expressions.FirExpression
169
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
17-
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
18-
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
1910
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
2011
import org.jetbrains.kotlin.fir.types.ConeKotlinType
2112
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
2213
import org.jetbrains.kotlin.fir.types.classId
2314
import org.jetbrains.kotlin.fir.types.resolvedType
2415
import org.jetbrains.kotlin.name.Name
25-
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.CreateDataFrameDslImplApproximation
16+
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
17+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
2618
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
19+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
20+
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
2721

2822
fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
2923
val callReturnType = call.resolvedType
@@ -38,44 +32,6 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
3832

3933
val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
4034
when (name) {
41-
"toDataFrameDsl" -> {
42-
val list = call.argumentList as FirResolvedArgumentList
43-
val lambda = (list.arguments.singleOrNull() as? FirAnonymousFunctionExpression)?.anonymousFunction
44-
val statements = lambda?.body?.statements
45-
if (statements != null) {
46-
val receiver = CreateDataFrameDslImplApproximation()
47-
statements.filterIsInstance<FirFunctionCall>().forEach {
48-
val schemaProcessor = it.loadInterpreter() ?: return@forEach
49-
interpret(
50-
it,
51-
schemaProcessor,
52-
mapOf("dsl" to Interpreter.Success(receiver), "call" to Interpreter.Success(call)),
53-
reporter
54-
)
55-
}
56-
PluginDataFrameSchema(receiver.columns)
57-
} else {
58-
PluginDataFrameSchema(emptyList())
59-
}
60-
}
61-
"toDataFrame" -> {
62-
val list = call.argumentList as FirResolvedArgumentList
63-
val argument = list.mapping.entries.firstOrNull { it.value.name == Name.identifier("maxDepth") }?.key
64-
val maxDepth = when (argument) {
65-
null -> 0
66-
is FirLiteralExpression -> (argument.value as Number).toInt()
67-
else -> null
68-
}
69-
if (maxDepth != null) {
70-
toDataFrame(maxDepth, call, TraverseConfiguration())
71-
} else {
72-
PluginDataFrameSchema(emptyList())
73-
}
74-
}
75-
"toDataFrameDefault" -> {
76-
val maxDepth = 0
77-
toDataFrame(maxDepth, call, TraverseConfiguration())
78-
}
7935
"Aggregate" -> {
8036
val groupByCall = call.explicitReceiver as? FirFunctionCall
8137
val interpreter = groupByCall?.loadInterpreter(session)

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/ExpectedArgumentDelegates.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,11 @@ fun <T> AbstractInterpreter<T>.kproperty(
7070

7171
internal fun <T> AbstractInterpreter<T>.string(
7272
name: ArgumentName? = null
73-
): ExpectedArgumentProvider<String
74-
> = arg(name, lens = Interpreter.Value)
73+
): ExpectedArgumentProvider<String> =
74+
arg(name, lens = Interpreter.Value)
75+
76+
internal fun <T> AbstractInterpreter<T>.dsl(
77+
name: ArgumentName? = null
78+
): ExpectedArgumentProvider<(Any, Map<String, Interpreter.Success<Any?>>) -> Unit> =
79+
arg(name, lens = Interpreter.Dsl, defaultValue = Present(value = {_, _ -> }))
80+

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/add.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
99
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
1010
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
1111
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
12+
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
1213
import org.jetbrains.kotlinx.dataframe.plugin.impl.string
1314
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
1415

@@ -48,11 +49,11 @@ class AddDslApproximation(val columns: MutableList<SimpleCol>)
4849

4950
class AddWithDsl : AbstractSchemaModificationInterpreter() {
5051
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
51-
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl)
52+
val Arguments.body by dsl()
5253

5354
override fun Arguments.interpret(): PluginDataFrameSchema {
5455
val addDsl = AddDslApproximation(receiver.columns().toMutableList())
55-
body(addDsl)
56+
body(addDsl, emptyMap())
5657
return PluginDataFrameSchema(addDsl.columns)
5758
}
5859
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import org.jetbrains.kotlin.fir.declarations.utils.effectiveVisibility
88
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
99
import org.jetbrains.kotlin.fir.declarations.utils.isStatic
1010
import org.jetbrains.kotlin.fir.expressions.FirCallableReferenceAccess
11-
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
11+
import org.jetbrains.kotlin.fir.expressions.FirExpression
1212
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
1313
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
1414
import org.jetbrains.kotlin.fir.java.JavaTypeParameterStack
@@ -46,6 +46,7 @@ import org.jetbrains.kotlin.name.Name
4646
import org.jetbrains.kotlin.name.StandardClassIds
4747
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
4848
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
49+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
4950
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
5051
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
5152
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
@@ -54,26 +55,56 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
5455
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
5556
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
5657
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
58+
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
5759
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
5860
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
5961
import java.util.*
6062

63+
class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
64+
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
65+
val Arguments.body by dsl()
66+
override fun Arguments.interpret(): PluginDataFrameSchema {
67+
val dsl = CreateDataFrameDslImplApproximation()
68+
body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver)))
69+
return PluginDataFrameSchema(dsl.columns)
70+
}
71+
}
72+
73+
class ToDataFrame : AbstractSchemaModificationInterpreter() {
74+
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
75+
val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))
76+
77+
override fun Arguments.interpret(): PluginDataFrameSchema {
78+
return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration())
79+
}
80+
}
81+
82+
class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
83+
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
84+
85+
override fun Arguments.interpret(): PluginDataFrameSchema {
86+
return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
87+
}
88+
}
89+
90+
private const val DEFAULT_MAX_DEPTH = 0
91+
6192
class Properties0 : AbstractInterpreter<Unit>() {
6293
val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
63-
val Arguments.call: FirFunctionCall by arg()
94+
val Arguments.explicitReceiver: FirExpression? by arg()
6495
val Arguments.maxDepth: Int by arg()
65-
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl, defaultValue = Present(value = {}))
96+
val Arguments.body by dsl()
6697

6798
override fun Arguments.interpret() {
6899
dsl.configuration.maxDepth = maxDepth
69-
body(dsl.configuration.traverseConfiguration)
70-
val schema = toDataFrame(dsl.configuration.maxDepth, call, dsl.configuration.traverseConfiguration)
100+
body(dsl.configuration.traverseConfiguration, emptyMap())
101+
val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration)
71102
dsl.columns.addAll(schema.columns())
72103
}
73104
}
74105

75106
class CreateDataFrameConfiguration {
76-
var maxDepth = 0
107+
var maxDepth = DEFAULT_MAX_DEPTH
77108
var traverseConfiguration: TraverseConfiguration = TraverseConfiguration()
78109
}
79110

@@ -123,7 +154,7 @@ class Exclude1 : AbstractInterpreter<Unit>() {
123154
@OptIn(SymbolInternals::class)
124155
internal fun KotlinTypeFacade.toDataFrame(
125156
maxDepth: Int,
126-
call: FirFunctionCall,
157+
explicitReceiver: FirExpression?,
127158
traverseConfiguration: TraverseConfiguration
128159
): PluginDataFrameSchema {
129160
fun ConeKotlinType.isValueType() =
@@ -238,7 +269,7 @@ internal fun KotlinTypeFacade.toDataFrame(
238269
}
239270
}
240271

241-
val receiver = call.explicitReceiver ?: return PluginDataFrameSchema(emptyList())
272+
val receiver = explicitReceiver ?: return PluginDataFrameSchema(emptyList())
242273
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema(emptyList())
243274
return when {
244275
arg.isStarProjection -> PluginDataFrameSchema(emptyList())

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ fun <T> KotlinTypeFacade.interpret(
206206
}
207207

208208
is Interpreter.Dsl -> {
209-
{ receiver: Any ->
209+
{ receiver: Any, dslArguments: Map<String, Interpreter.Success<Any?>> ->
210+
val map = mapOf("dsl" to Interpreter.Success(receiver)) + dslArguments
210211
(it.expression as FirAnonymousFunctionExpression)
211212
.anonymousFunction.body!!
212213
.statements.filterIsInstance<FirFunctionCall>()
@@ -215,7 +216,7 @@ fun <T> KotlinTypeFacade.interpret(
215216
interpret(
216217
call,
217218
schemaProcessor,
218-
mapOf("dsl" to Interpreter.Success(receiver)),
219+
map,
219220
reporter
220221
)
221222
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth0
7272
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
7373
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
7474
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
75+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
76+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDefault
77+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDsl
7578
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
7679

7780
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
@@ -171,6 +174,9 @@ internal inline fun <reified T> String.load(): T {
171174
"ColsOf1" -> ColsOf1()
172175
"ColsAtAnyDepth0" -> ColsAtAnyDepth0()
173176
"FrameCols0" -> FrameCols0()
177+
"toDataFrameDsl" -> ToDataFrameDsl()
178+
"toDataFrame" -> ToDataFrame()
179+
"toDataFrameDefault" -> ToDataFrameDefault()
174180
else -> error("$this")
175181
} as T
176182
}

0 commit comments

Comments
 (0)