Skip to content

Commit bc4a236

Browse files
authored
Merge pull request #908 from Kotlin/dataFrameOfPairs
[Compiler plugin] Support dataFrameOf(Pair<String, List<T>)
2 parents 0ea2488 + 3596f05 commit bc4a236

File tree

9 files changed

+146
-45
lines changed

9 files changed

+146
-45
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/constructors.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ public inline fun <reified C> dataFrameOf(vararg header: String, fill: (String)
279279

280280
public fun dataFrameOf(header: Iterable<String>): DataFrameBuilder = DataFrameBuilder(header.asList())
281281

282+
@Refine
283+
@Interpretable("DataFrameOf3")
282284
public fun dataFrameOf(vararg columns: Pair<String, List<Any?>>): DataFrame<*> =
283285
columns.map { it.second.toColumn(it.first, Infer.Type) }.toDataFrame()
284286

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class FunctionCallTransformer(
190190
val tokenFir = token.toClassSymbol(session)!!.fir
191191
tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol })
192192

193-
return buildLetCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
193+
return buildScopeFunctionCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
194194
}
195195
}
196196

@@ -253,7 +253,7 @@ class FunctionCallTransformer(
253253
val keyToken = groupMarker.toClassSymbol(session)!!.fir
254254
keyToken.callShapeData = CallShapeData.RefinedType(groupApis.map { it.scope.symbol })
255255

256-
return buildLetCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
256+
return buildScopeFunctionCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
257257
}
258258
}
259259

@@ -305,18 +305,17 @@ class FunctionCallTransformer(
305305
private fun Name.asTokenName() = identifierOrNullIfSpecial?.titleCase() ?: DEFAULT_NAME
306306

307307
@OptIn(SymbolInternals::class)
308-
private fun buildLetCall(
308+
private fun buildScopeFunctionCall(
309309
call: FirFunctionCall,
310310
originalSymbol: FirNamedFunctionSymbol,
311311
dataSchemaApis: List<DataSchemaApi>,
312312
additionalDeclarations: List<FirClass>
313313
): FirFunctionCall {
314314

315-
val explicitReceiver = call.explicitReceiver ?: return call
316-
val receiverType = explicitReceiver.resolvedType
315+
val explicitReceiver = call.explicitReceiver
316+
val receiverType = explicitReceiver?.resolvedType
317317
val returnType = call.resolvedType
318-
val resolvedLet = findLet()
319-
val parameter = resolvedLet.valueParameterSymbols[0]
318+
val scopeFunction = if (explicitReceiver != null) findLet() else findRun()
320319

321320
// original call is inserted later
322321
call.transformCalleeReference(object : FirTransformer<Nothing?>() {
@@ -350,20 +349,23 @@ class FunctionCallTransformer(
350349
returnTypeRef = buildResolvedTypeRef {
351350
type = returnType
352351
}
353-
val itName = Name.identifier("it")
354-
val parameterSymbol = FirValueParameterSymbol(itName)
355-
valueParameters += buildValueParameter {
356-
moduleData = session.moduleData
357-
origin = FirDeclarationOrigin.Source
358-
returnTypeRef = buildResolvedTypeRef {
359-
type = receiverType
352+
val parameterSymbol = receiverType?.let {
353+
val itName = Name.identifier("it")
354+
val parameterSymbol = FirValueParameterSymbol(itName)
355+
valueParameters += buildValueParameter {
356+
moduleData = session.moduleData
357+
origin = FirDeclarationOrigin.Source
358+
returnTypeRef = buildResolvedTypeRef {
359+
type = receiverType
360+
}
361+
this.name = itName
362+
this.symbol = parameterSymbol
363+
containingFunctionSymbol = fSymbol
364+
isCrossinline = false
365+
isNoinline = false
366+
isVararg = false
360367
}
361-
this.name = itName
362-
this.symbol = parameterSymbol
363-
containingFunctionSymbol = fSymbol
364-
isCrossinline = false
365-
isNoinline = false
366-
isVararg = false
368+
parameterSymbol
367369
}
368370
body = buildBlock {
369371
this.coneTypeOrNull = returnType
@@ -375,20 +377,23 @@ class FunctionCallTransformer(
375377
statements += additionalDeclarations
376378

377379
statements += buildReturnExpression {
378-
val itPropertyAccess = buildPropertyAccessExpression {
379-
coneTypeOrNull = receiverType
380-
calleeReference = buildResolvedNamedReference {
381-
name = itName
382-
resolvedSymbol = parameterSymbol
380+
if (parameterSymbol != null) {
381+
val itPropertyAccess = buildPropertyAccessExpression {
382+
coneTypeOrNull = receiverType
383+
calleeReference = buildResolvedNamedReference {
384+
name = parameterSymbol.name
385+
resolvedSymbol = parameterSymbol
386+
}
387+
}
388+
if (callDispatchReceiver != null) {
389+
call.replaceDispatchReceiver(itPropertyAccess)
390+
}
391+
call.replaceExplicitReceiver(itPropertyAccess)
392+
if (callExtensionReceiver != null) {
393+
call.replaceExtensionReceiver(itPropertyAccess)
383394
}
384395
}
385-
if (callDispatchReceiver != null) {
386-
call.replaceDispatchReceiver(itPropertyAccess)
387-
}
388-
call.replaceExplicitReceiver(itPropertyAccess)
389-
if (callExtensionReceiver != null) {
390-
call.replaceExtensionReceiver(itPropertyAccess)
391-
}
396+
392397
result = call
393398
this.target = target
394399
}
@@ -397,11 +402,19 @@ class FunctionCallTransformer(
397402
isLambda = true
398403
hasExplicitParameterList = false
399404
typeRef = buildResolvedTypeRef {
400-
type = ConeClassLikeTypeImpl(
401-
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
402-
typeArguments = arrayOf(receiverType, returnType),
403-
isNullable = false
404-
)
405+
type = if (receiverType != null) {
406+
ConeClassLikeTypeImpl(
407+
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
408+
typeArguments = arrayOf(receiverType, returnType),
409+
isNullable = false
410+
)
411+
} else {
412+
ConeClassLikeTypeImpl(
413+
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function0"))),
414+
typeArguments = arrayOf(returnType),
415+
isNullable = false
416+
)
417+
}
405418
}
406419
invocationKind = EventOccurrencesRange.EXACTLY_ONCE
407420
inlineStatus = InlineStatus.Inline
@@ -413,11 +426,13 @@ class FunctionCallTransformer(
413426
val newCall1 = buildFunctionCall {
414427
source = call.source
415428
this.coneTypeOrNull = returnType
416-
typeArguments += buildTypeProjectionWithVariance {
417-
typeRef = buildResolvedTypeRef {
418-
type = receiverType
429+
if (receiverType != null) {
430+
typeArguments += buildTypeProjectionWithVariance {
431+
typeRef = buildResolvedTypeRef {
432+
type = receiverType
433+
}
434+
variance = Variance.INVARIANT
419435
}
420-
variance = Variance.INVARIANT
421436
}
422437

423438
typeArguments += buildTypeProjectionWithVariance {
@@ -429,11 +444,14 @@ class FunctionCallTransformer(
429444
dispatchReceiver = null
430445
this.explicitReceiver = callExplicitReceiver
431446
extensionReceiver = callExtensionReceiver ?: callDispatchReceiver
432-
argumentList = buildResolvedArgumentList(original = null, linkedMapOf(argument to parameter.fir))
447+
argumentList = buildResolvedArgumentList(
448+
original = null,
449+
linkedMapOf(argument to scopeFunction.valueParameterSymbols[0].fir)
450+
)
433451
calleeReference = buildResolvedNamedReference {
434452
source = call.calleeReference.source
435-
this.name = Name.identifier("let")
436-
resolvedSymbol = resolvedLet
453+
this.name = scopeFunction.name
454+
resolvedSymbol = scopeFunction
437455
}
438456
}
439457
return newCall1
@@ -565,5 +583,9 @@ class FunctionCallTransformer(
565583
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("let")).single()
566584
}
567585

586+
private fun findRun(): FirFunctionSymbol<*> {
587+
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("run")).single { it.typeParameterSymbols.size == 1 }
588+
}
589+
568590
private fun String.titleCase() = replaceFirstChar { it.uppercaseChar() }
569591
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/dataFrameOf.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
22
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
33

4+
import org.jetbrains.kotlin.fir.expressions.FirExpression
5+
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
46
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
57
import org.jetbrains.kotlin.fir.types.commonSuperTypeOrNull
68
import org.jetbrains.kotlin.fir.types.resolvedType
9+
import org.jetbrains.kotlin.fir.types.type
710
import org.jetbrains.kotlin.fir.types.typeContext
811
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
912
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
@@ -36,3 +39,18 @@ class DataFrameBuilderInvoke0 : AbstractSchemaModificationInterpreter() {
3639
return PluginDataFrameSchema(columns)
3740
}
3841
}
42+
43+
class DataFrameOf3 : AbstractSchemaModificationInterpreter() {
44+
val Arguments.columns: List<Interpreter.Success<Pair<*, *>>> by arg()
45+
46+
override fun Arguments.interpret(): PluginDataFrameSchema {
47+
val res = columns.map {
48+
val it = it.value
49+
val name = (it.first as? FirLiteralExpression)?.value as? String
50+
val type = (it.second as? FirExpression)?.resolvedType?.typeArguments?.getOrNull(0)?.type
51+
if (name == null || type == null) return PluginDataFrameSchema(emptyList())
52+
simpleColumnOf(name, type)
53+
}
54+
return PluginDataFrameSchema(res)
55+
}
56+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
2+
3+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
4+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
5+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
6+
7+
class PairConstructor : AbstractInterpreter<Pair<*, *>>() {
8+
val Arguments.receiver: Any? by arg(lens = Interpreter.Id)
9+
val Arguments.that: Any? by arg(lens = Interpreter.Id)
10+
override fun Arguments.interpret(): Pair<*, *> {
11+
return receiver to that
12+
}
13+
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ fun <T> KotlinTypeFacade.interpret(
136136
is FirCallableReferenceAccess -> {
137137
toKPropertyApproximation(it, session)
138138
}
139-
139+
is FirFunctionCall -> {
140+
it.loadInterpreter()?.let { processor ->
141+
interpret(it, processor, emptyMap(), reporter)
142+
}
143+
}
140144
else -> null
141145
}
142146

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,20 @@ import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
6060
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
6161
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
6262
import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
63+
import org.jetbrains.kotlin.fir.expressions.UnresolvedExpressionTypeAccess
6364
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
65+
import org.jetbrains.kotlin.fir.references.resolved
66+
import org.jetbrains.kotlin.fir.references.symbol
67+
import org.jetbrains.kotlin.fir.references.toResolvedNamedFunctionSymbol
6468
import org.jetbrains.kotlin.fir.resolve.fqName
6569
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
6670
import org.jetbrains.kotlin.fir.types.classId
6771
import org.jetbrains.kotlin.fir.types.coneType
72+
import org.jetbrains.kotlin.name.CallableId
6873
import org.jetbrains.kotlin.name.ClassId
74+
import org.jetbrains.kotlin.name.FqName
6975
import org.jetbrains.kotlin.name.Name
76+
import org.jetbrains.kotlin.name.StandardClassIds
7077
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
7178
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId
7279
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate
@@ -76,12 +83,14 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
7683
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
7784
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
7885
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
86+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
7987
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
8088
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
8189
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
8290
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
8391
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
8492
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
93+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.PairConstructor
8594
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReadExcel
8695
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
8796
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameColumn
@@ -91,8 +100,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
91100
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToTop
92101
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Update0
93102
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.UpdateWith0
103+
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
94104

105+
@OptIn(UnresolvedExpressionTypeAccess::class)
95106
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
107+
if (
108+
calleeReference.toResolvedNamedFunctionSymbol()?.callableId == Names.TO &&
109+
coneTypeOrNull?.classId == Names.PAIR
110+
) {
111+
return PairConstructor()
112+
}
96113
val symbol =
97114
(calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol ?: return null
98115
val argName = Name.identifier("interpreter")
@@ -208,6 +225,7 @@ internal inline fun <reified T> String.load(): T {
208225
"ToTop" -> ToTop()
209226
"Update0" -> Update0()
210227
"Aggregate" -> Aggregate()
228+
"DataFrameOf3" -> DataFrameOf3()
211229
else -> error("$this")
212230
} as T
213231
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.jetbrains.kotlinx.dataframe.plugin.utils
77

8+
import org.jetbrains.kotlin.name.CallableId
89
import org.jetbrains.kotlin.name.ClassId
910
import org.jetbrains.kotlin.name.FqName
1011
import org.jetbrains.kotlin.name.Name
@@ -50,6 +51,9 @@ object Names {
5051
val LOCAL_DATE_CLASS_ID = kotlinx.datetime.LocalDate::class.classId()
5152
val LOCAL_DATE_TIME_CLASS_ID = kotlinx.datetime.LocalDateTime::class.classId()
5253
val INSTANT_CLASS_ID = kotlinx.datetime.Instant::class.classId()
54+
55+
val PAIR = ClassId(FqName("kotlin"), Name.identifier("Pair"))
56+
val TO = CallableId(FqName("kotlin"), Name.identifier("to"))
5357
}
5458

5559
private fun KClass<*>.classId(): ClassId {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val df = dataFrameOf(
8+
"a" to listOf(1, 2),
9+
"b" to listOf("str1", "str2"),
10+
)
11+
val i: Int = df.a[0]
12+
val str: String = df.b[0]
13+
return "OK"
14+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ public void testDataFrameOf() {
8282
runTest("testData/box/dataFrameOf.kt");
8383
}
8484

85+
@Test
86+
@TestMetadata("dataFrameOf_to.kt")
87+
public void testDataFrameOf_to() {
88+
runTest("testData/box/dataFrameOf_to.kt");
89+
}
90+
8591
@Test
8692
@TestMetadata("dataFrameOf_vararg.kt")
8793
public void testDataFrameOf_vararg() {

0 commit comments

Comments
 (0)