Skip to content

Commit 9031d0b

Browse files
authored
Merge pull request #1127 from Kotlin/unfold
[Compiler plugin] Support `unfold`
2 parents 3e1cca9 + dde2fa8 commit 9031d0b

File tree

9 files changed

+261
-77
lines changed

9 files changed

+261
-77
lines changed

core/api/core.api

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4328,10 +4328,11 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt {
43284328
}
43294329

43304330
public final class org/jetbrains/kotlinx/dataframe/api/UnfoldKt {
4331-
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
43324331
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
4332+
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KCallable;ILkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
43334333
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
43344334
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
4335+
public static synthetic fun unfold$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KCallable;ILkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
43354336
}
43364337

43374338
public final class org/jetbrains/kotlinx/dataframe/api/UngroupKt {
@@ -5611,6 +5612,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/api/ToSequenceKt {
56115612
public static final fun toSequenceImpl (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/reflect/KType;)Lkotlin/sequences/Sequence;
56125613
}
56135614

5615+
public final class org/jetbrains/kotlinx/dataframe/impl/api/UnfoldKt {
5616+
public static final fun unfoldImpl (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
5617+
}
5618+
56145619
public final class org/jetbrains/kotlinx/dataframe/impl/api/UpdateKt {
56155620
public static final fun updateImpl (Lorg/jetbrains/kotlinx/dataframe/api/Update;Lkotlin/jvm/functions/Function3;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
56165621
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/unfold.kt

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,23 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector
66
import org.jetbrains.kotlinx.dataframe.DataColumn
77
import org.jetbrains.kotlinx.dataframe.DataFrame
88
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
9-
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
9+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
10+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1011
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
11-
import org.jetbrains.kotlinx.dataframe.impl.api.canBeUnfolded
12-
import org.jetbrains.kotlinx.dataframe.impl.api.createDataFrameImpl
13-
import org.jetbrains.kotlinx.dataframe.typeClass
12+
import org.jetbrains.kotlinx.dataframe.impl.api.unfoldImpl
13+
import kotlin.reflect.KCallable
1414
import kotlin.reflect.KProperty
1515

16-
public inline fun <reified T> DataColumn<T>.unfold(): AnyCol =
17-
when (kind()) {
18-
ColumnKind.Group, ColumnKind.Frame -> this
16+
/**
 * Unfolds this column by delegating to [unfoldImpl].
 *
 * The DSL block passed to [unfoldImpl] makes a single `properties(...)` call that
 * receives [roots] and [maxDepth] unchanged.
 *
 * @param roots callables forwarded as the `roots` argument of `properties(...)`.
 * @param maxDepth depth limit forwarded to `properties(...)`; defaults to `0`.
 * @return the column produced by [unfoldImpl].
 */
public inline fun <reified T> DataColumn<T>.unfold(vararg roots: KCallable<*>, maxDepth: Int = 0): AnyCol {
    return unfoldImpl { properties(roots = roots, maxDepth) }
}
1918

20-
else -> when {
21-
!typeClass.canBeUnfolded -> this
22-
23-
else -> values()
24-
.createDataFrameImpl(typeClass) { (this as CreateDataFrameDsl<T>).properties() }
25-
.asColumnGroup(name())
26-
.asDataColumn()
27-
}
28-
}
29-
30-
public fun <T> DataFrame<T>.unfold(columns: ColumnsSelector<T, *>): DataFrame<T> = replace(columns).with { it.unfold() }
19+
/**
 * Replaces every column selected by [columns] with its unfolded counterpart.
 *
 * Each selected column is passed through `unfoldImpl` with a DSL block that calls
 * `properties(...)` using [roots] and [maxDepth].
 *
 * @param roots callables forwarded as the `roots` argument of `properties(...)`.
 * @param maxDepth depth limit forwarded to `properties(...)`; defaults to `0`.
 * @param columns selector of the columns to replace.
 * @return a data frame in which the selected columns have been replaced.
 */
@Refine
@Interpretable("DataFrameUnfold")
public fun <T> DataFrame<T>.unfold(
    vararg roots: KCallable<*>,
    maxDepth: Int = 0,
    columns: ColumnsSelector<T, *>,
): DataFrame<T> {
    return replace(columns).with { column ->
        column.unfoldImpl { properties(roots = roots, maxDepth) }
    }
}
3126

3227
/** Unfolds the columns with the given names by delegating to the selector-based `unfold` overload. */
public fun <T> DataFrame<T>.unfold(vararg columns: String): DataFrame<T> = unfold { columns.toColumnSet() }
3328

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package org.jetbrains.kotlinx.dataframe.impl.api
2+
3+
import org.jetbrains.kotlinx.dataframe.AnyCol
4+
import org.jetbrains.kotlinx.dataframe.DataColumn
5+
import org.jetbrains.kotlinx.dataframe.api.CreateDataFrameDsl
6+
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
7+
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
8+
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
9+
import org.jetbrains.kotlinx.dataframe.typeClass
10+
11+
/**
 * Converts a value column into a column group by building a data frame from its values.
 *
 * The column is returned unchanged when it is already a column group or a frame column,
 * or when its `typeClass` reports that it cannot be unfolded.
 *
 * @param body DSL block run against the [CreateDataFrameDsl] that builds the data frame.
 * @return this column, or a column group carrying this column's name.
 */
@PublishedApi
internal fun <T> DataColumn<T>.unfoldImpl(body: CreateDataFrameDsl<T>.() -> Unit): AnyCol {
    val columnKind = kind()

    // Groups and frames are already structured; nothing to unfold.
    if (columnKind == ColumnKind.Group || columnKind == ColumnKind.Frame) return this

    // Value columns of a non-unfoldable class are kept as they are.
    if (!typeClass.canBeUnfolded) return this

    return values()
        .createDataFrameImpl(typeClass) { (this as CreateDataFrameDsl<T>).body() }
        .asColumnGroup(name())
        .asDataColumn()
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package org.jetbrains.kotlinx.dataframe.api
2+
3+
import io.kotest.matchers.shouldBe
4+
import io.kotest.matchers.types.shouldBeInstanceOf
5+
import org.jetbrains.kotlinx.dataframe.AnyFrame
6+
import org.junit.Test
7+
import kotlin.reflect.typeOf
8+
9+
/** Tests for the selector-based `DataFrame.unfold` overload. */
class UnfoldTests {
    data class A(val str: String, val i: Int)

    data class Person(
        val firstName: String,
        val lastName: String,
        val age: Int,
        val city: String?,
    )

    data class Group(val id: String, val participants: List<Person>)

    /** A data-class column becomes a group with one nested column per property. */
    @Test
    fun unfold() {
        val source = dataFrameOf("col" to listOf(A("123", 321)))

        val unfolded = source.unfold { col("col") }

        unfolded[pathOf("col", "str")][0] shouldBe "123"
        unfolded[pathOf("col", "i")][0] shouldBe 321
    }

    /** With the default depth a nested list stays a value; `maxDepth = 2` turns it into frames. */
    @Test
    fun `unfold deep`() {
        val firstGroup = Group(
            "1",
            listOf(
                Person("Alice", "Cooper", 15, "London"),
                Person("Bob", "Dylan", 45, "Dubai"),
            ),
        )
        val secondGroup = Group(
            "2",
            listOf(
                Person("Charlie", "Daniels", 20, "Moscow"),
                Person("Charlie", "Chaplin", 40, "Milan"),
            ),
        )
        val source = dataFrameOf("col" to listOf(firstGroup, secondGroup))

        val shallow = source.unfold { col("col") }
        shallow[pathOf("col", "participants")].type() shouldBe typeOf<List<Person>>()

        val deep = source.unfold(maxDepth = 2) { col("col") }
        deep[pathOf("col", "participants")][0].shouldBeInstanceOf<AnyFrame> { frame ->
            frame["firstName"][0] shouldBe "Alice"
        }
    }

    /** Columns of plain values are left untouched by `unfold`. */
    @Test
    fun `keep value type`() {
        val numbers = listOf(1, 2, 3, 4)
        val source = dataFrameOf("int" to numbers)

        val intColumn = source.unfold { col("int") }["int"]

        intColumn.type() shouldBe typeOf<Int>()
        intColumn.values() shouldBe numbers
    }
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -197,61 +197,6 @@ internal fun KotlinTypeFacade.toDataFrame(
197197
arg: ConeTypeProjection,
198198
traverseConfiguration: TraverseConfiguration,
199199
): PluginDataFrameSchema {
200-
201-
val anyType = session.builtinTypes.nullableAnyType.type
202-
203-
fun ConeKotlinType.isValueType() =
204-
this.isArrayTypeOrNullableArrayType ||
205-
this.classId == StandardClassIds.Unit ||
206-
this.classId == StandardClassIds.Any ||
207-
this.classId == StandardClassIds.Map ||
208-
this.classId == StandardClassIds.MutableMap ||
209-
this.classId == StandardClassIds.String ||
210-
this.classId in StandardClassIds.primitiveTypes ||
211-
this.classId in StandardClassIds.unsignedTypes ||
212-
classId in setOf(
213-
Names.DURATION_CLASS_ID,
214-
Names.LOCAL_DATE_CLASS_ID,
215-
Names.LOCAL_DATE_TIME_CLASS_ID,
216-
Names.INSTANT_CLASS_ID,
217-
Names.DATE_TIME_PERIOD_CLASS_ID,
218-
Names.DATE_TIME_UNIT_CLASS_ID,
219-
Names.TIME_ZONE_CLASS_ID
220-
) ||
221-
this.isSubtypeOf(
222-
StandardClassIds.Number.constructClassLikeType(emptyArray(), isNullable = true),
223-
session
224-
) ||
225-
this.toRegularClassSymbol(session)?.isEnumClass ?: false ||
226-
this.isSubtypeOf(
227-
Names.TEMPORAL_ACCESSOR_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
228-
) ||
229-
this.isSubtypeOf(
230-
Names.TEMPORAL_AMOUNT_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
231-
)
232-
233-
234-
fun FirNamedFunctionSymbol.isGetterLike(): Boolean {
235-
val functionName = this.name.asString()
236-
return (functionName.startsWith("get") || functionName.startsWith("is")) &&
237-
this.valueParameterSymbols.isEmpty() &&
238-
this.typeParameterSymbols.isEmpty()
239-
}
240-
241-
fun ConeKotlinType.hasProperties(): Boolean {
242-
val symbol = this.toRegularClassSymbol(session) as? FirClassSymbol<*> ?: return false
243-
val scope = symbol.unsubstitutedScope(
244-
session,
245-
ScopeSession(),
246-
withForcedTypeCalculator = false,
247-
memberRequiredPhase = null
248-
)
249-
250-
return scope.collectAllProperties().any { it.visibility == Visibilities.Public } ||
251-
scope.collectAllFunctions().any { it.visibility == Visibilities.Public && it.isGetterLike() }
252-
}
253-
254-
255200
val excludes =
256201
traverseConfiguration.excludeProperties.mapNotNullTo(mutableSetOf()) { it.calleeReference.toResolvedPropertySymbol() }
257202
val excludedClasses = traverseConfiguration.excludeClasses.mapTo(mutableSetOf()) { it.argument.resolvedType }
@@ -322,7 +267,7 @@ internal fun KotlinTypeFacade.toDataFrame(
322267

323268
val keepSubtree =
324269
depth >= maxDepth && !fieldKind.shouldBeConvertedToColumnGroup && !fieldKind.shouldBeConvertedToFrameColumn
325-
if (keepSubtree || returnType.isValueType() || returnType.classId in preserveClasses || it in preserveProperties) {
270+
if (keepSubtree || returnType.isValueType(session) || returnType.classId in preserveClasses || it in preserveProperties) {
326271
SimpleDataColumn(
327272
name,
328273
TypeApproximation(
@@ -349,7 +294,7 @@ internal fun KotlinTypeFacade.toDataFrame(
349294
ConeStarProjection -> session.builtinTypes.nullableAnyType.type
350295
else -> session.builtinTypes.nullableAnyType.type
351296
}
352-
if (type.isValueType()) {
297+
if (type.isValueType(session)) {
353298
val columnType = List.constructClassLikeType(arrayOf(type), returnType.isNullable)
354299
.withNullability(ConeNullability.create(makeNullable), session.typeContext)
355300
.wrap()
@@ -364,7 +309,7 @@ internal fun KotlinTypeFacade.toDataFrame(
364309
}
365310

366311
arg.type?.let { type ->
367-
if (type.isValueType() || !type.hasProperties()) {
312+
if (!type.canBeUnfolded(session)) {
368313
return PluginDataFrameSchema(listOf(simpleColumnOf("value", type)))
369314
}
370315
}
@@ -383,6 +328,60 @@ internal fun KotlinTypeFacade.toDataFrame(
383328
}
384329
}
385330

331+
/**
 * Tells whether this type can be unfolded: it must not be a value type (see [isValueType])
 * and it must have properties (see [hasProperties]).
 */
fun ConeKotlinType.canBeUnfolded(session: FirSession): Boolean {
    if (isValueType(session)) return false
    return hasProperties(session)
}
333+
334+
/**
 * Returns `true` when this type is treated as a plain value.
 *
 * A type counts as a value when it is an array, one of the listed standard or date-time
 * classes, an enum, or a subtype of `Number`, of the temporal accessor class, or of the
 * temporal amount class. Checks run in the order written and stop at the first match.
 */
private fun ConeKotlinType.isValueType(session: FirSession): Boolean {
    if (this.isArrayTypeOrNullableArrayType) return true

    val id = this.classId

    val isBasicStandardClass =
        id == StandardClassIds.Unit ||
            id == StandardClassIds.Any ||
            id == StandardClassIds.Map ||
            id == StandardClassIds.MutableMap ||
            id == StandardClassIds.String
    if (isBasicStandardClass) return true

    if (id in StandardClassIds.primitiveTypes || id in StandardClassIds.unsignedTypes) return true

    val isKnownDateTimeClass = id in setOf(
        Names.DURATION_CLASS_ID,
        Names.LOCAL_DATE_CLASS_ID,
        Names.LOCAL_DATE_TIME_CLASS_ID,
        Names.INSTANT_CLASS_ID,
        Names.DATE_TIME_PERIOD_CLASS_ID,
        Names.DATE_TIME_UNIT_CLASS_ID,
        Names.TIME_ZONE_CLASS_ID
    )
    if (isKnownDateTimeClass) return true

    val isNumber = this.isSubtypeOf(
        StandardClassIds.Number.constructClassLikeType(emptyArray(), isNullable = true),
        session
    )
    if (isNumber) return true

    // A type without a regular class symbol is not an enum.
    if (this.toRegularClassSymbol(session)?.isEnumClass == true) return true

    val isTemporalAccessor = this.isSubtypeOf(
        Names.TEMPORAL_ACCESSOR_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
    )
    if (isTemporalAccessor) return true

    return this.isSubtypeOf(
        Names.TEMPORAL_AMOUNT_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
    )
}
363+
364+
365+
/**
 * Returns `true` when the class behind this type has at least one public property or at
 * least one public getter-like function (see [isGetterLike]).
 *
 * Returns `false` when the type does not resolve to a class symbol.
 */
private fun ConeKotlinType.hasProperties(session: FirSession): Boolean {
    val classSymbol = this.toRegularClassSymbol(session) as? FirClassSymbol<*> ?: return false
    val memberScope = classSymbol.unsubstitutedScope(
        session,
        ScopeSession(),
        withForcedTypeCalculator = false,
        memberRequiredPhase = null
    )

    val hasPublicProperty = memberScope.collectAllProperties().any { property ->
        property.visibility == Visibilities.Public
    }
    if (hasPublicProperty) return true

    return memberScope.collectAllFunctions().any { function ->
        function.visibility == Visibilities.Public && function.isGetterLike()
    }
}
377+
378+
/**
 * Returns `true` for functions that look like getters: the name starts with `get` or `is`
 * and the function declares neither value parameters nor type parameters.
 *
 * The prefix check is purely textual, so any name beginning with those letters matches.
 */
private fun FirNamedFunctionSymbol.isGetterLike(): Boolean {
    val hasGetterPrefix = this.name.asString().let { it.startsWith("get") || it.startsWith("is") }
    return hasGetterPrefix &&
        this.valueParameterSymbols.isEmpty() &&
        this.typeParameterSymbols.isEmpty()
}
384+
386385
// org.jetbrains.kotlinx.dataframe.codeGen.getFieldKind
387386
private fun ConeKotlinType.getFieldKind(session: FirSession) = FieldKind.of(
388387
this,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
2+
3+
import org.jetbrains.kotlinx.dataframe.api.replace
4+
import org.jetbrains.kotlinx.dataframe.api.with
5+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
6+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
7+
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
8+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
9+
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
10+
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
11+
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataColumn
12+
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataFrame
13+
import org.jetbrains.kotlinx.dataframe.plugin.impl.asSimpleColumn
14+
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
15+
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
16+
import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema
17+
18+
/**
 * Schema interpreter registered under the name `"DataFrameUnfold"`.
 *
 * Every selected column that is a [SimpleDataColumn] of an unfoldable type is replaced by a
 * [SimpleColumnGroup] whose children come from `toDataFrame`; all other columns are kept.
 */
class DataFrameUnfold : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver: PluginDataFrameSchema by dataFrame()

    // NOTE(review): the runtime `unfold` vararg parameter is named `roots`, while the ignored
    // argument here is called `properties` — confirm the name is intended.
    val Arguments.properties by ignore()
    val Arguments.maxDepth: Int by arg(defaultValue = Present(0))
    val Arguments.columns: ColumnsResolver by arg()

    override fun Arguments.interpret(): PluginDataFrameSchema {
        return receiver.asDataFrame()
            .replace { columns }
            .with { original ->
                // Null for anything that is not a plain data column (e.g. groups and frames).
                val valueType = (original.asSimpleColumn() as? SimpleDataColumn)?.type?.type
                if (valueType == null || !valueType.canBeUnfolded(session)) {
                    original
                } else {
                    SimpleColumnGroup(
                        original.name(),
                        toDataFrame(maxDepth, valueType, TraverseConfiguration()).columns(),
                    ).asDataColumn()
                }
            }
            .toPluginDataFrameSchema()
    }
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys
101101
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
102102
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
103103
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
104+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameUnfold
104105
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameXs
105106
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop0
106107
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop1
@@ -464,6 +465,7 @@ internal inline fun <reified T> String.load(): T {
464465
"DataFrameXs" -> DataFrameXs()
465466
"GroupByXs" -> GroupByXs()
466467
"ConcatWithKeys" -> ConcatWithKeys()
468+
"DataFrameUnfold" -> DataFrameUnfold()
467469
else -> error("$this")
468470
} as T
469471
}

0 commit comments

Comments
 (0)