Skip to content

Commit c45a097

Browse files
authored
Merge pull request #942 from Kotlin/toDataFrame-nullable
[Compiler plugin] Propagate nullability in toDataFrame tree conversion
2 parents 2970631 + a49bda7 commit c45a097

File tree

7 files changed

+176
-41
lines changed

7 files changed

+176
-41
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ import org.jetbrains.kotlin.fir.symbols.SymbolInternals
2525
import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl
2626
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
2727
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
28+
import org.jetbrains.kotlin.fir.types.ConeFlexibleType
2829
import org.jetbrains.kotlin.fir.types.ConeKotlinType
30+
import org.jetbrains.kotlin.fir.types.ConeNullability
2931
import org.jetbrains.kotlin.fir.types.ConeStarProjection
3032
import org.jetbrains.kotlin.fir.types.ConeTypeParameterType
33+
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
3134
import org.jetbrains.kotlin.fir.types.canBeNull
3235
import org.jetbrains.kotlin.fir.types.classId
3336
import org.jetbrains.kotlin.fir.types.coneType
@@ -41,15 +44,18 @@ import org.jetbrains.kotlin.fir.types.resolvedType
4144
import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
4245
import org.jetbrains.kotlin.fir.types.toSymbol
4346
import org.jetbrains.kotlin.fir.types.type
47+
import org.jetbrains.kotlin.fir.types.typeContext
4448
import org.jetbrains.kotlin.fir.types.upperBoundIfFlexible
4549
import org.jetbrains.kotlin.fir.types.withArguments
50+
import org.jetbrains.kotlin.fir.types.withNullability
4651
import org.jetbrains.kotlin.name.ClassId
4752
import org.jetbrains.kotlin.name.FqName
4853
import org.jetbrains.kotlin.name.Name
4954
import org.jetbrains.kotlin.name.StandardClassIds
5055
import org.jetbrains.kotlin.name.StandardClassIds.List
5156
import org.jetbrains.kotlinx.dataframe.codeGen.*
5257
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
58+
import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap
5359
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
5460
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
5561
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
@@ -71,27 +77,31 @@ import java.util.*
7177
class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
7278
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
7379
val Arguments.body by dsl()
80+
val Arguments.typeArg0: ConeTypeProjection? by arg(lens = Interpreter.Id)
81+
7482
override fun Arguments.interpret(): PluginDataFrameSchema {
7583
val dsl = CreateDataFrameDslImplApproximation()
76-
body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver)))
84+
body(dsl, mapOf("typeArg0" to Interpreter.Success(typeArg0)))
7785
return PluginDataFrameSchema(dsl.columns)
7886
}
7987
}
8088

8189
class ToDataFrame : AbstractSchemaModificationInterpreter() {
8290
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
8391
val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))
92+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
8493

8594
override fun Arguments.interpret(): PluginDataFrameSchema {
86-
return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration())
95+
return toDataFrame(maxDepth.toInt(), typeArg0, TraverseConfiguration())
8796
}
8897
}
8998

9099
class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
91100
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
101+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
92102

93103
override fun Arguments.interpret(): PluginDataFrameSchema {
94-
return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
104+
return toDataFrame(DEFAULT_MAX_DEPTH, typeArg0, TraverseConfiguration())
95105
}
96106
}
97107

@@ -109,14 +119,14 @@ private const val DEFAULT_MAX_DEPTH = 0
109119

110120
class Properties0 : AbstractInterpreter<Unit>() {
111121
val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
112-
val Arguments.explicitReceiver: FirExpression? by arg()
113122
val Arguments.maxDepth: Int by arg()
114123
val Arguments.body by dsl()
124+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
115125

116126
override fun Arguments.interpret() {
117127
dsl.configuration.maxDepth = maxDepth
118128
body(dsl.configuration.traverseConfiguration, emptyMap())
119-
val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration)
129+
val schema = toDataFrame(dsl.configuration.maxDepth, typeArg0, dsl.configuration.traverseConfiguration)
120130
dsl.columns.addAll(schema.columns())
121131
}
122132
}
@@ -172,8 +182,8 @@ class Exclude1 : AbstractInterpreter<Unit>() {
172182
@OptIn(SymbolInternals::class)
173183
internal fun KotlinTypeFacade.toDataFrame(
174184
maxDepth: Int,
175-
explicitReceiver: FirExpression?,
176-
traverseConfiguration: TraverseConfiguration
185+
arg: ConeTypeProjection,
186+
traverseConfiguration: TraverseConfiguration,
177187
): PluginDataFrameSchema {
178188
fun ConeKotlinType.isValueType() =
179189
this.isArrayTypeOrNullableArrayType ||
@@ -197,7 +207,7 @@ internal fun KotlinTypeFacade.toDataFrame(
197207
val preserveClasses = traverseConfiguration.preserveClasses.mapNotNullTo(mutableSetOf()) { it.classId }
198208
val preserveProperties = traverseConfiguration.preserveProperties.mapNotNullTo(mutableSetOf()) { it.calleeReference.toResolvedPropertySymbol() }
199209

200-
fun convert(classLike: ConeKotlinType, depth: Int): List<SimpleCol> {
210+
fun convert(classLike: ConeKotlinType, depth: Int, makeNullable: Boolean): List<SimpleCol> {
201211
val symbol = classLike.toRegularClassSymbol(session) ?: return emptyList()
202212
val scope = symbol.unsubstitutedScope(session, ScopeSession(), false, FirResolvePhase.STATUS)
203213
val declarations = if (symbol.fir is FirJavaClass) {
@@ -260,7 +270,7 @@ internal fun KotlinTypeFacade.toDataFrame(
260270

261271
val keepSubtree = depth >= maxDepth && !fieldKind.shouldBeConvertedToColumnGroup && !fieldKind.shouldBeConvertedToFrameColumn
262272
if (keepSubtree || returnType.isValueType() || returnType.classId in preserveClasses || it in preserveProperties) {
263-
SimpleDataColumn(name, TypeApproximation(returnType))
273+
SimpleDataColumn(name, TypeApproximation(returnType.withNullability(ConeNullability.create(makeNullable), session.typeContext)))
264274
} else if (
265275
returnType.isSubtypeOf(StandardClassIds.Iterable.constructClassLikeType(arrayOf(ConeStarProjection)), session) ||
266276
returnType.isSubtypeOf(StandardClassIds.Iterable.constructClassLikeType(arrayOf(ConeStarProjection), isNullable = true), session)
@@ -271,30 +281,28 @@ internal fun KotlinTypeFacade.toDataFrame(
271281
else -> session.builtinTypes.nullableAnyType.type
272282
}
273283
if (type.isValueType()) {
274-
SimpleDataColumn(name,
275-
TypeApproximation(
276-
List.constructClassLikeType(
277-
arrayOf(type),
278-
returnType.isNullable
279-
)
280-
)
281-
)
284+
val columnType = List.constructClassLikeType(arrayOf(type), returnType.isNullable)
285+
.withNullability(ConeNullability.create(makeNullable), session.typeContext)
286+
.wrap()
287+
SimpleDataColumn(name, columnType)
282288
} else {
283-
SimpleFrameColumn(name, convert(type, depth + 1))
289+
SimpleFrameColumn(name, convert(type, depth + 1, makeNullable = false))
284290
}
285291
} else {
286-
SimpleColumnGroup(name, convert(returnType, depth + 1))
292+
SimpleColumnGroup(name, convert(returnType, depth + 1, returnType.isNullable || makeNullable))
287293
}
288294
}
289295
}
290296

291-
val receiver = explicitReceiver ?: return PluginDataFrameSchema.EMPTY
292-
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema.EMPTY
293297
return when {
294298
arg.isStarProjection -> PluginDataFrameSchema.EMPTY
295299
else -> {
296-
val classLike = arg.type as? ConeClassLikeType ?: return PluginDataFrameSchema.EMPTY
297-
val columns = convert(classLike, 0)
300+
val classLike = when (val type = arg.type) {
301+
is ConeClassLikeType -> type
302+
is ConeFlexibleType -> type.upperBound
303+
else -> null
304+
} ?: return PluginDataFrameSchema.EMPTY
305+
val columns = convert(classLike, 0, makeNullable = classLike.isNullable)
298306
PluginDataFrameSchema(columns)
299307
}
300308
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,29 +90,43 @@ fun <T> KotlinTypeFacade.interpret(
9090
val refinedArguments: RefinedArguments = functionCall.collectArgumentExpressions()
9191

9292
val defaultArguments = processor.expectedArguments.filter { it.defaultValue is Present }.map { it.name }.toSet()
93-
val actualArgsMap = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
94-
val conflictingKeys = additionalArguments.keys intersect actualArgsMap.keys
93+
val actualValueArguments = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
94+
val conflictingKeys = additionalArguments.keys intersect actualValueArguments.keys
9595
if (conflictingKeys.isNotEmpty()) {
9696
if (isTest) {
9797
interpretationFrameworkError("Conflicting keys: $conflictingKeys")
9898
}
9999
return null
100100
}
101101
val expectedArgsMap = processor.expectedArguments
102-
.filterNot { it.name.startsWith("typeArg") }
103102
.associateBy { it.name }.toSortedMap().minus(additionalArguments.keys)
104103

105-
val unexpectedArguments = expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments
104+
val typeArguments = buildMap {
105+
functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
106+
val key = "typeArg$index"
107+
val lens = expectedArgsMap[key]?.lens ?: return@forEachIndexed
108+
val value: Any = if (lens == Interpreter.Id) {
109+
firTypeProjection.toConeTypeProjection()
110+
} else {
111+
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
112+
if (type is ConeIntersectionType) return@forEachIndexed
113+
Marker(type)
114+
}
115+
put(key, Interpreter.Success(value))
116+
}
117+
}
118+
119+
val unexpectedArguments = (expectedArgsMap.keys - defaultArguments) != (actualValueArguments.keys + typeArguments.keys - defaultArguments)
106120
if (unexpectedArguments) {
107121
if (isTest) {
108122
val message = buildString {
109123
appendLine("ERROR: Different set of arguments")
110124
appendLine("Implementation class: $processor")
111-
appendLine("Not found in actual: ${expectedArgsMap.keys - actualArgsMap.keys}")
112-
val diff = actualArgsMap.keys - expectedArgsMap.keys
125+
appendLine("Not found in actual: ${expectedArgsMap.keys - actualValueArguments.keys}")
126+
val diff = actualValueArguments.keys - expectedArgsMap.keys
113127
appendLine("Passed, but not expected: ${diff}")
114128
appendLine("add arguments to an interpeter:")
115-
appendLine(diff.map { actualArgsMap[it] })
129+
appendLine(diff.map { actualValueArguments[it] })
116130
}
117131
interpretationFrameworkError(message)
118132
}
@@ -121,6 +135,7 @@ fun <T> KotlinTypeFacade.interpret(
121135

122136
val arguments = mutableMapOf<String, Interpreter.Success<Any?>>()
123137
arguments += additionalArguments
138+
arguments += typeArguments
124139
val interpretationResults = refinedArguments.refinedArguments.mapNotNull {
125140
val name = it.name.identifier
126141
val expectedArgument = expectedArgsMap[name] ?: error("$processor $name")
@@ -269,17 +284,6 @@ fun <T> KotlinTypeFacade.interpret(
269284
value?.let { value1 -> it.name.identifier to value1 }
270285
}
271286

272-
functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
273-
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
274-
if (type is ConeIntersectionType) return@forEachIndexed
275-
// val approximation = TypeApproximationImpl(
276-
// type.classId!!.asFqNameString(),
277-
// type.isMarkedNullable
278-
// )
279-
val approximation = Marker(type)
280-
arguments["typeArg$index"] = Interpreter.Success(approximation)
281-
}
282-
283287
return if (interpretationResults.size == refinedArguments.refinedArguments.size) {
284288
arguments.putAll(interpretationResults)
285289
when (val res = processor.interpret(arguments, this)) {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
@DataSchema
7+
data class D(
8+
val s: String
9+
)
10+
11+
class Subtree(
12+
val p: Int,
13+
val l: List<Int>,
14+
val ld: List<D>,
15+
)
16+
17+
class Root(val a: Subtree)
18+
19+
class MyList(val l: List<Root?>): List<Root?> by l
20+
21+
fun box(): String {
22+
val l = listOf(
23+
Root(Subtree(123, listOf(1), listOf(D("ff")))),
24+
null
25+
)
26+
val df = MyList(l).toDataFrame(maxDepth = 2)
27+
df.compareSchemas(strict = true)
28+
return "OK"
29+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
@DataSchema
7+
data class D(
8+
val s: String
9+
)
10+
11+
fun box(): String {
12+
val df1 = listOf(D("bb"), null).toDataFrame()
13+
df1.schema().print()
14+
df1.compileTimeSchema().print()
15+
return "OK"
16+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
@DataSchema
7+
data class D(
8+
val s: String
9+
)
10+
11+
class Subtree(
12+
val p: Int,
13+
val l: List<Int>,
14+
val ld: List<D>,
15+
)
16+
17+
class Root(val a: Subtree)
18+
19+
fun box(): String {
20+
val l = listOf(
21+
Root(Subtree(123, listOf(1), listOf(D("ff")))),
22+
null
23+
)
24+
val df = l.toDataFrame(maxDepth = 2)
25+
df.compareSchemas(strict = true)
26+
return "OK"
27+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
@DataSchema
7+
data class D(
8+
val s: String
9+
)
10+
11+
class Subtree(
12+
val p: Int,
13+
val l: List<Int>,
14+
val ld: List<D>,
15+
)
16+
17+
class Root(val a: Subtree?)
18+
19+
fun box(): String {
20+
val l = listOf(
21+
Root(Subtree(123, listOf(1), listOf(D("ff")))),
22+
Root(null)
23+
)
24+
val df = l.toDataFrame(maxDepth = 2)
25+
df.compareSchemas(strict = true)
26+
return "OK"
27+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,12 @@ public void testToDataFrame_column() {
418418
runTest("testData/box/toDataFrame_column.kt");
419419
}
420420

421+
@Test
422+
@TestMetadata("toDataFrame_customIterable.kt")
423+
public void testToDataFrame_customIterable() {
424+
runTest("testData/box/toDataFrame_customIterable.kt");
425+
}
426+
421427
@Test
422428
@TestMetadata("toDataFrame_dataSchema.kt")
423429
public void testToDataFrame_dataSchema() {
@@ -436,6 +442,24 @@ public void testToDataFrame_from() {
436442
runTest("testData/box/toDataFrame_from.kt");
437443
}
438444

445+
@Test
446+
@TestMetadata("toDataFrame_nullableList.kt")
447+
public void testToDataFrame_nullableList() {
448+
runTest("testData/box/toDataFrame_nullableList.kt");
449+
}
450+
451+
@Test
452+
@TestMetadata("toDataFrame_nullableListSubtree.kt")
453+
public void testToDataFrame_nullableListSubtree() {
454+
runTest("testData/box/toDataFrame_nullableListSubtree.kt");
455+
}
456+
457+
@Test
458+
@TestMetadata("toDataFrame_nullableSubtree.kt")
459+
public void testToDataFrame_nullableSubtree() {
460+
runTest("testData/box/toDataFrame_nullableSubtree.kt");
461+
}
462+
439463
@Test
440464
@TestMetadata("toDataFrame_superType.kt")
441465
public void testToDataFrame_superType() {

0 commit comments

Comments
 (0)