Skip to content

Commit 7818bea

Browse files
committed
Fixes Issue #1221 using Nikita's solution of #663
1 parent 9e2fa0d commit 7818bea

20 files changed

+138
-54
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cast.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,5 @@ public fun <C> TransformableColumnSet<*>.cast(): TransformableColumnSet<C> = thi
9999
public fun <C> TransformableSingleColumn<*>.cast(): TransformableSingleColumn<C> = this as TransformableSingleColumn<C>
100100

101101
public fun <C> ColumnReference<*>.cast(): ColumnReference<C> = this as ColumnReference<C>
102+
103+
public fun <T, G> GroupBy<*, *>.cast(): GroupBy<T, G> = this as GroupBy<T, G>

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/codeGen/ReplCodeGenerator.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dataframe.impl.codeGen
22

33
import org.jetbrains.kotlinx.dataframe.AnyFrame
44
import org.jetbrains.kotlinx.dataframe.AnyRow
5+
import org.jetbrains.kotlinx.dataframe.api.GroupBy
56
import org.jetbrains.kotlinx.dataframe.codeGen.Code
67
import org.jetbrains.kotlinx.dataframe.codeGen.CodeWithTypeCastGenerator
78
import kotlin.reflect.KClass
@@ -11,9 +12,10 @@ internal interface ReplCodeGenerator {
1112

1213
fun process(df: AnyFrame, property: KProperty<*>? = null): CodeWithTypeCastGenerator
1314

14-
fun process(row: AnyRow, property: KProperty<*>? = null): CodeWithConverter
1515
fun process(row: AnyRow, property: KProperty<*>? = null): CodeWithTypeCastGenerator
1616

17+
fun process(groupBy: GroupBy<*, *>): CodeWithTypeCastGenerator
18+
1719
fun process(markerClass: KClass<*>): Code
1820

1921
companion object {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/codeGen/ReplCodeGeneratorImpl.kt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,27 @@ internal class ReplCodeGeneratorImpl : ReplCodeGenerator {
7979
return generate(schema = targetSchema, name = markerInterfacePrefix, isOpen = true)
8080
}
8181

82-
fun generate(schema: DataFrameSchema, name: String, isOpen: Boolean): CodeWithConverter {
82+
override fun process(groupBy: GroupBy<*, *>): CodeWithTypeCastGenerator {
83+
val key = generate(
84+
schema = groupBy.keys.schema(),
85+
name = markerInterfacePrefix + "Keys",
86+
isOpen = false,
87+
)
88+
val group = generate(
89+
schema = groupBy.groups.schema.value,
90+
name = markerInterfacePrefix + "Groups",
91+
isOpen = false,
92+
)
93+
94+
val keyTypeName = (key.typeCastGenerator as TypeCastGenerator.DataFrameApi).types.single()
95+
val groupTypeName = (group.typeCastGenerator as TypeCastGenerator.DataFrameApi).types.single()
96+
97+
return CodeWithTypeCastGenerator(
98+
declarations = key.declarations + "\n" + group.declarations,
99+
typeCastGenerator = TypeCastGenerator.DataFrameApi(keyTypeName, groupTypeName),
100+
)
101+
}
102+
83103
fun generate(schema: DataFrameSchema, name: String, isOpen: Boolean): CodeWithTypeCastGenerator {
84104
val result = generator.generate(
85105
schema = schema,

dataframe-jupyter/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ internal class Integration(private val notebook: Notebook, private val options:
153153
null
154154
}
155155

156+
private fun KotlinKernelHost.updateGroupByVariable(
157+
instance: GroupBy<*, *>,
158+
property: KProperty<*>,
159+
codeGen: ReplCodeGenerator,
160+
): VariableName? =
161+
execute(
162+
codeWithTypeCastGenerator = codeGen.process(instance),
163+
property = property,
164+
type = GroupBy::class.createStarProjectedType(false),
165+
)
166+
156167
override fun Builder.onLoaded() {
157168
if (version != null) {
158169
if (enableExperimentalCsv?.toBoolean() == true) {
@@ -291,6 +302,7 @@ internal class Integration(private val notebook: Notebook, private val options:
291302
is AnyRow -> updateAnyRowVariable(instance, property, codeGen)
292303
is AnyFrame -> updateAnyFrameVariable(instance, property, codeGen)
293304
is ImportDataSchema -> updateImportDataSchemaVariable(instance, property)
305+
is GroupBy<*, *> -> updateGroupByVariable(instance, property, codeGen)
294306
else -> error("${instance::class} should not be handled by Dataframe field handler")
295307
}
296308
}
@@ -300,7 +312,8 @@ internal class Integration(private val notebook: Notebook, private val options:
300312
value is ColumnGroup<*> ||
301313
value is AnyRow ||
302314
value is AnyFrame ||
303-
value is ImportDataSchema
315+
value is ImportDataSchema ||
316+
value is GroupBy<*, *>
304317
})
305318

306319
fun KotlinKernelHost.addDataSchemas(classes: List<KClass<*>>) {

dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,51 @@ class CodeGenerationTests : DataFrameJupyterTest() {
9292
df2.group.a
9393
""".checkCompilation()
9494
}
95+
96+
// Issue #1221, #663
97+
@Test
98+
fun `GroupBy code generation`() {
99+
@Language("kts")
100+
val a = """
101+
val ab = dataFrameOf("a", "b")(1, 2)
102+
ab.groupBy { a }.aggregate { sum { b } into "bSum" }
103+
""".checkCompilation()
104+
105+
@Language("kts")
106+
val b = """
107+
val ab = dataFrameOf("a", "b")(1, 2)
108+
val grouped = ab.groupBy { a }
109+
grouped.aggregate { sum { b } into "bSum" }
110+
""".checkCompilation()
111+
112+
@Language("kts")
113+
val c = """
114+
val grouped = dataFrameOf("a", "b")(1, 2).groupBy("a")
115+
grouped.aggregate { sum { b } into "bSum" }
116+
""".checkCompilation()
117+
118+
@Language("kts")
119+
val d = """
120+
val grouped = dataFrameOf("a", "b")(1, 2).groupBy("a")
121+
grouped.keys.a
122+
""".checkCompilation()
123+
124+
@Language("kts")
125+
val e = """
126+
val grouped = dataFrameOf("a", "b")(1, 2).groupBy { "a"<Int>() named "k" }
127+
grouped.keys.k
128+
""".checkCompilation()
129+
130+
@Language("kts")
131+
val f = """
132+
val groupBy = dataFrameOf("a")("1", "11", "2", "22").groupBy { expr { "a"<String>().length } named "k" }
133+
groupBy.keys.k
134+
""".checkCompilation()
135+
136+
@Language("kts")
137+
val g = """
138+
val groupBy = dataFrameOf("a")("1", "11", "2", "22").groupBy { expr { "a"<String>().length } named "k" }.add("newCol") { 42 }
139+
groupBy.aggregate { newCol into "newCol" }
140+
""".checkCompilation()
141+
}
95142
}

docs/StardustDocs/resources/api/generate_docs/notebook_test_generate_docs_1.html

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@
177177
</style>
178178
</head>
179179
<body>
180-
<table class="dataframe" id="df_-553648128"></table>
180+
<table class="dataframe" id="df_-637534208"></table>
181181

182182
<p class="dataframe_description"></p>
183183
</body>
@@ -458,23 +458,23 @@
458458

459459
/*<!--*/
460460
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"user: String\">user</span>", children: [], rightAlign: false, values: ["Alice","Bob"] },
461-
{ name: "<span title=\"orders: DataFrame<*>\">orders</span>", children: [], rightAlign: false, values: [{ frameId: -553648127, value: "<b>DataFrame 2 x 2</b>" },{ frameId: -553648126, value: "<b>DataFrame 3 x 2</b>" }] },
462-
], id: -553648128, rootId: -553648128, totalRows: 2 } ) });
461+
{ name: "<span title=\"orders: DataFrame<*>\">orders</span>", children: [], rightAlign: false, values: [{ frameId: -637534207, value: "<b>DataFrame 2 x 2</b>" },{ frameId: -637534206, value: "<b>DataFrame 3 x 2</b>" }] },
462+
], id: -637534208, rootId: -637534208, totalRows: 2 } ) });
463463
/*-->*/
464464

465465
/*<!--*/
466466
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"orderId: Int\">orderId</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">101</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">102</span></span>"] },
467467
{ name: "<span title=\"amount: Double\">amount</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">50.0</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">75.5</span></span>"] },
468-
], id: -553648127, rootId: -553648128, totalRows: 2 } ) });
468+
], id: -637534207, rootId: -637534208, totalRows: 2 } ) });
469469
/*-->*/
470470

471471
/*<!--*/
472472
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"orderId: Int\">orderId</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">103</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">104</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">105</span></span>"] },
473473
{ name: "<span title=\"amount: Double\">amount</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">20.0</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">30.0</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">25.0</span></span>"] },
474-
], id: -553648126, rootId: -553648128, totalRows: 3 } ) });
474+
], id: -637534206, rootId: -637534208, totalRows: 3 } ) });
475475
/*-->*/
476476

477-
call_DataFrame(function() { DataFrame.renderTable(-553648128) });
477+
call_DataFrame(function() { DataFrame.renderTable(-637534208) });
478478

479479
function sendHeight() {
480480
const table = document.querySelector('table.dataframe');

docs/StardustDocs/resources/api/rename/notebook_test_rename_3.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@
177177
</style>
178178
</head>
179179
<body>
180-
<table class="dataframe" id="df_-553648125"></table>
180+
<table class="dataframe" id="df_-637534205"></table>
181181

182182
<p class="dataframe_description"></p>
183183
</body>
@@ -460,10 +460,10 @@
460460
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"ColumnA: Int\">ColumnA</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">1</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">2</span></span>"] },
461461
{ name: "<span title=\"column_b: String\">column_b</span>", children: [], rightAlign: false, values: ["a","b"] },
462462
{ name: "<span title=\"COLUMN-C: Boolean\">COLUMN-C</span>", children: [], rightAlign: false, values: ["true","false"] },
463-
], id: -553648125, rootId: -553648125, totalRows: 2 } ) });
463+
], id: -637534205, rootId: -637534205, totalRows: 2 } ) });
464464
/*-->*/
465465

466-
call_DataFrame(function() { DataFrame.renderTable(-553648125) });
466+
call_DataFrame(function() { DataFrame.renderTable(-637534205) });
467467

468468
function sendHeight() {
469469
const table = document.querySelector('table.dataframe');

docs/StardustDocs/resources/api/rename/notebook_test_rename_4.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@
177177
</style>
178178
</head>
179179
<body>
180-
<table class="dataframe" id="df_-553648124"></table>
180+
<table class="dataframe" id="df_-637534204"></table>
181181

182182
<p class="dataframe_description"></p>
183183
</body>
@@ -460,10 +460,10 @@
460460
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"columnA: Int\">columnA</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">1</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">2</span></span>"] },
461461
{ name: "<span title=\"column_b: String\">column_b</span>", children: [], rightAlign: false, values: ["a","b"] },
462462
{ name: "<span title=\"columnC: Boolean\">columnC</span>", children: [], rightAlign: false, values: ["true","false"] },
463-
], id: -553648124, rootId: -553648124, totalRows: 2 } ) });
463+
], id: -637534204, rootId: -637534204, totalRows: 2 } ) });
464464
/*-->*/
465465

466-
call_DataFrame(function() { DataFrame.renderTable(-553648124) });
466+
call_DataFrame(function() { DataFrame.renderTable(-637534204) });
467467

468468
function sendHeight() {
469469
const table = document.querySelector('table.dataframe');

docs/StardustDocs/resources/api/rename/notebook_test_rename_5.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@
177177
</style>
178178
</head>
179179
<body>
180-
<table class="dataframe" id="df_-553648123"></table>
180+
<table class="dataframe" id="df_-637534203"></table>
181181

182182
<p class="dataframe_description"></p>
183183
</body>
@@ -460,10 +460,10 @@
460460
call_DataFrame(function() { DataFrame.addTable({ cols: [{ name: "<span title=\"columnA: Int\">columnA</span>", children: [], rightAlign: true, values: ["<span class=\"formatted\" title=\"\"><span class=\"numbers\">1</span></span>","<span class=\"formatted\" title=\"\"><span class=\"numbers\">2</span></span>"] },
461461
{ name: "<span title=\"columnB: String\">columnB</span>", children: [], rightAlign: false, values: ["a","b"] },
462462
{ name: "<span title=\"columnC: Boolean\">columnC</span>", children: [], rightAlign: false, values: ["true","false"] },
463-
], id: -553648123, rootId: -553648123, totalRows: 2 } ) });
463+
], id: -637534203, rootId: -637534203, totalRows: 2 } ) });
464464
/*-->*/
465465

466-
call_DataFrame(function() { DataFrame.renderTable(-553648123) });
466+
call_DataFrame(function() { DataFrame.renderTable(-637534203) });
467467

468468
function sendHeight() {
469469
const table = document.querySelector('table.dataframe');

0 commit comments

Comments
 (0)