Skip to content

Commit 75790e8

Browse files
committed
Support adding ColumnGroups in add DSL.
1 parent 68974b8 commit 75790e8

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

33
import org.jetbrains.kotlinx.dataframe.AnyCol
4+
import org.jetbrains.kotlinx.dataframe.AnyColumnGroupAccessor
45
import org.jetbrains.kotlinx.dataframe.AnyFrame
56
import org.jetbrains.kotlinx.dataframe.AnyRow
67
import org.jetbrains.kotlinx.dataframe.Column
@@ -88,6 +89,7 @@ public inline fun <reified R, T, G> GroupBy<T, G>.add(
8889

8990
public class AddDsl<T>(@PublishedApi internal val df: DataFrame<T>) : ColumnsContainer<T> by df, ColumnSelectionDsl<T> {
9091

92+
// TODO: support adding a column into a column path
9193
// Columns accumulated by this DSL instance; `add` appends resolved column data here.
internal val columns = mutableListOf<AnyCol>()
9294

9395
/**
 * Resolves [column] against [df] and appends the resolved column data to [columns].
 *
 * @param column reference to the column that should be added.
 * @return the result of `columns.add`.
 * @throws IllegalStateException if [column] resolves to `null` in [df].
 */
public fun add(column: Column): Boolean {
    // Fail with a descriptive message instead of a bare NullPointerException from `!!`.
    val resolved = checkNotNull(column.resolveSingle(df)) {
        "Column cannot be resolved in the DataFrame: $column"
    }
    return columns.add(resolved.data)
}
@@ -104,6 +106,8 @@ public class AddDsl<T>(@PublishedApi internal val df: DataFrame<T>) : ColumnsCon
104106
): Boolean = add(df.map(name, infer, expression))
105107

106108
/**
 * Delegates to `add(name, Infer.Nulls, expression)`, passing the receiver string as the first argument.
 *
 * @param expression row expression forwarded unchanged to `add`.
 * @return whatever the delegated `add` call returns.
 */
public inline infix fun <reified R> String.from(noinline expression: RowExpression<T, R>): Boolean {
    val columnName = this
    return add(columnName, Infer.Nulls, expression)
}
109+
110+
// TODO: use the accessor's path instead of only its name
107111
/**
 * Adds a column for the receiver accessor by forwarding to the `String.from` overload.
 *
 * Only the accessor's `name()` is used here.
 *
 * @param expression row expression forwarded unchanged.
 * @return the result of the delegated `String.from` call.
 */
public inline infix fun <reified R> ColumnAccessor<R>.from(noinline expression: RowExpression<T, R>): Boolean {
    val accessorName = name()
    return accessorName.from(expression)
}
108112
/**
 * Delegates to `add(name, Infer.Nulls, expression)`, passing the receiver property's `name` as the first argument.
 *
 * @param expression row expression forwarded unchanged to `add`.
 * @return whatever the delegated `add` call returns.
 */
public inline infix fun <reified R> KProperty<R>.from(noinline expression: RowExpression<T, R>): Boolean {
    val propertyName = name
    return add(propertyName, Infer.Nulls, expression)
}
109113

@@ -114,4 +118,21 @@ public class AddDsl<T>(@PublishedApi internal val df: DataFrame<T>) : ColumnsCon
114118
/**
 * Renames the receiver column to [name] and passes the renamed column to `add`.
 *
 * @param name the name given to `rename`.
 * @return the result of the `add` call.
 */
public infix fun Column.into(name: String): Boolean {
    val renamed = rename(name)
    return add(renamed)
}
115119
/**
 * Forwards to the string-based `into`, using the `name()` of [column] as the target name.
 *
 * @param column accessor whose name is used as the target name.
 * @return the result of the delegated `into` call.
 */
public infix fun <R> ColumnReference<R>.into(column: ColumnAccessor<R>): Boolean {
    val targetName = column.name()
    return into(targetName)
}
116120
/**
 * Forwards to the string-based `into`, using the `name` of the property [column] as the target name.
 *
 * @param column property whose name is used as the target name.
 * @return the result of the delegated `into` call.
 */
public infix fun <R> ColumnReference<R>.into(column: KProperty<R>): Boolean {
    val targetName = column.name
    return into(targetName)
}
121+
122+
/**
 * Enables the `"name" { ... }` syntax by forwarding the receiver string and [body] to `group`.
 *
 * @param body DSL block forwarded unchanged to `group`.
 */
public operator fun String.invoke(body: AddDsl<T>.() -> Unit) {
    group(this, body)
}
123+
/**
 * Enables the `accessor from { ... }` syntax by forwarding the receiver accessor and [body] to `group`.
 *
 * @param body DSL block forwarded unchanged to `group`.
 */
public infix fun AnyColumnGroupAccessor.from(body: AddDsl<T>.() -> Unit) {
    group(this, body)
}
124+
125+
/**
 * Forwards to the string-based `group`, using the `name()` of [column] as the group name.
 *
 * @param column accessor whose name is passed on as the group name.
 * @param body DSL block forwarded unchanged.
 */
public fun group(column: AnyColumnGroupAccessor, body: AddDsl<T>.() -> Unit) {
    val groupName = column.name()
    group(groupName, body)
}
126+
/**
 * Runs [body] on a fresh [AddDsl] built over the same [df], converts the columns it collected
 * into a column group called [name], and adds that group to this DSL.
 *
 * @param name name passed to `toColumnGroup`.
 * @param body DSL block executed against the nested [AddDsl].
 */
public fun group(name: String, body: AddDsl<T>.() -> Unit) {
    val nested = AddDsl(df).apply(body)
    add(nested.columns.toColumnGroup(name))
}
131+
132+
/**
 * Wraps [body] in an [AddGroup] without executing it.
 *
 * @param body DSL block stored in the returned [AddGroup].
 * @return a new [AddGroup] holding [body].
 */
public fun group(body: AddDsl<T>.() -> Unit): AddGroup<T> {
    return AddGroup(body)
}
133+
134+
/**
 * Forwards [groupName] and the receiver's stored `body` to `group`.
 *
 * @param groupName name passed on to `group`.
 */
public infix fun AddGroup<T>.into(groupName: String) {
    val groupBody = body
    group(groupName, groupBody)
}
135+
/**
 * Forwards to the string-based `into`, using the `name()` of [column] as the group name.
 *
 * @param column accessor whose name is passed on as the group name.
 */
public infix fun AddGroup<T>.into(column: AnyColumnGroupAccessor) {
    val groupName = column.name()
    into(groupName)
}
117136
}
137+
138+
/**
 * Holds an [AddDsl] block that has not been executed yet.
 *
 * Instances are created by `AddDsl.group { ... }`; the stored [body] is passed to `group(name, body)`
 * by the `into` overloads declared in [AddDsl].
 */
public data class AddGroup<T>(internal val body: AddDsl<T>.() -> Unit)

src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Modify.kt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,10 @@ class Modify : TestBase() {
777777
df.add {
778778
"year of birth" from 2021 - age
779779
age gt 18 into "is adult"
780-
name.lastName.length() into "last name length"
781-
"full name" from { name.firstName + " " + name.lastName }
780+
"details" {
781+
name.lastName.length() into "last name length"
782+
"full name" from { name.firstName + " " + name.lastName }
783+
}
782784
}
783785
// SampleEnd
784786
}
@@ -792,14 +794,17 @@ class Modify : TestBase() {
792794
val isAdult = column<Boolean>("is adult")
793795
val fullName = column<String>("full name")
794796
val name by columnGroup()
797+
val details by columnGroup()
795798
val firstName by name.column<String>()
796799
val lastName by name.column<String>()
797800

798801
df.add {
799802
yob from 2021 - age
800803
age gt 18 into isAdult
801-
lastName.length() into lastNameLength
802-
fullName from { firstName() + " " + lastName() }
804+
details from {
805+
lastName.length() into lastNameLength
806+
fullName from { firstName() + " " + lastName() }
807+
}
803808
}
804809
// SampleEnd
805810
}
@@ -810,8 +815,10 @@ class Modify : TestBase() {
810815
df.add {
811816
"year of birth" from 2021 - "age"<Int>()
812817
"age"<Int>() gt 18 into "is adult"
813-
"name"["lastName"]<String>().length() into "last name length"
814-
"full name" from { "name"["firstName"]<String>() + " " + "name"["lastName"]<String>() }
818+
"details" {
819+
"name"["lastName"]<String>().length() into "last name length"
820+
"full name" from { "name"["firstName"]<String>() + " " + "name"["lastName"]<String>() }
821+
}
815822
}
816823
// SampleEnd
817824
}

src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ import org.jetbrains.kotlinx.dataframe.api.frameColumn
6363
import org.jetbrains.kotlinx.dataframe.api.gather
6464
import org.jetbrains.kotlinx.dataframe.api.getColumn
6565
import org.jetbrains.kotlinx.dataframe.api.getColumnGroup
66+
import org.jetbrains.kotlinx.dataframe.api.getColumns
6667
import org.jetbrains.kotlinx.dataframe.api.getFrameColumn
6768
import org.jetbrains.kotlinx.dataframe.api.getValue
6869
import org.jetbrains.kotlinx.dataframe.api.group
@@ -883,14 +884,32 @@ class DataFrameTests : BaseTest() {
883884
fun `add several columns`() {
884885
val now = 2020
885886
val expected = typed.rows().map { now - it.age }
886-
887-
fun AnyFrame.check() = (1..3).forEach { this["year$it"].toList() shouldBe expected }
888-
889-
typed.add {
890-
"year1" from { now - age }
891-
"year2" from now - age
892-
now - age into "year3"
893-
}.check()
887+
val g by columnGroup()
888+
889+
val df = typed.add {
890+
"a" from { now - age }
891+
"b" from now - age
892+
now - age into "c"
893+
"d" {
894+
"f" from { now - age }
895+
}
896+
group {
897+
g from {
898+
add(age.map { now - it }.named("h"))
899+
}
900+
} into "e"
901+
}.remove { allBefore("a") }
902+
903+
df.columnNames() shouldBe listOf("a", "b", "c", "d", "e")
904+
df["d"].kind() shouldBe ColumnKind.Group
905+
df["e"].kind() shouldBe ColumnKind.Group
906+
df.getColumnGroup("d").columnNames() shouldBe listOf("f")
907+
df.getColumnGroup("e").getColumnGroup("g").columnNames() shouldBe listOf("h")
908+
val cols = df.getColumns { allDfs() }
909+
cols.size shouldBe 5
910+
cols.forEach {
911+
it.toList() shouldBe expected
912+
}
894913
}
895914

896915
@Test

0 commit comments

Comments
 (0)