Skip to content

Commit 45036c9

Browse files
committed
Update.asFrame {} now takes Update.filter into account, applying the operation only to the rows that match the filter
1 parent e086f50 commit 45036c9

File tree

3 files changed

+88
-2
lines changed
  • core/src

3 files changed

+88
-2
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
1111
import org.jetbrains.kotlinx.dataframe.RowValueFilter
1212
import org.jetbrains.kotlinx.dataframe.Selector
1313
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
14+
import org.jetbrains.kotlinx.dataframe.impl.api.asFrameImpl
1415
import org.jetbrains.kotlinx.dataframe.impl.api.updateImpl
1516
import org.jetbrains.kotlinx.dataframe.impl.api.updateWithValuePerColumnImpl
1617
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
@@ -57,7 +58,7 @@ public infix fun <T, C> Update<T, C>.with(expression: UpdateExpression<T, C, C?>
5758
}
5859

5960
/**
 * Updates the selected column groups by evaluating [expression] on each group viewed as a [DataFrame].
 *
 * All work is delegated to [asFrameImpl].
 *
 * @param expression receives a selected column group (as both receiver and argument) and returns the
 *   frame that should take its place.
 * @return the [DataFrame] produced by [asFrameImpl].
 */
public infix fun <T, C, R> Update<T, DataRow<C>>.asFrame(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> {
    return asFrameImpl(expression)
}
6162

6263
/**
 * Returns this same [Update] instance, re-typed so that its column value type is nullable (`C?`).
 *
 * No data is copied or converted; this is only a cast.
 */
public fun <T, C> Update<T, C>.asNullable(): Update<T, C?> {
    return this as Update<T, C?>
}
6364

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/update.kt

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,28 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
44
import org.jetbrains.kotlinx.dataframe.AnyRow
55
import org.jetbrains.kotlinx.dataframe.DataColumn
66
import org.jetbrains.kotlinx.dataframe.DataFrame
7+
import org.jetbrains.kotlinx.dataframe.DataFrameExpression
8+
import org.jetbrains.kotlinx.dataframe.DataRow
79
import org.jetbrains.kotlinx.dataframe.RowValueFilter
810
import org.jetbrains.kotlinx.dataframe.Selector
911
import org.jetbrains.kotlinx.dataframe.api.AddDataRow
1012
import org.jetbrains.kotlinx.dataframe.api.Update
13+
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
1114
import org.jetbrains.kotlinx.dataframe.api.cast
1215
import org.jetbrains.kotlinx.dataframe.api.indices
1316
import org.jetbrains.kotlinx.dataframe.api.isEmpty
1417
import org.jetbrains.kotlinx.dataframe.api.name
1518
import org.jetbrains.kotlinx.dataframe.api.replace
19+
import org.jetbrains.kotlinx.dataframe.api.rows
20+
import org.jetbrains.kotlinx.dataframe.api.toColumn
1621
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
1722
import org.jetbrains.kotlinx.dataframe.api.with
1823
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
1924
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
2025
import org.jetbrains.kotlinx.dataframe.columns.size
2126
import org.jetbrains.kotlinx.dataframe.impl.columns.AddDataRowImpl
2227
import org.jetbrains.kotlinx.dataframe.impl.createDataCollector
28+
import org.jetbrains.kotlinx.dataframe.index
2329
import org.jetbrains.kotlinx.dataframe.type
2430
import kotlin.reflect.full.isSubclassOf
2531
import kotlin.reflect.full.withNullability
@@ -40,10 +46,45 @@ internal fun <T, C> Update<T, C>.updateWithValuePerColumnImpl(selector: Selector
4046
}
4147
}
4248

49+
/**
 * Implementation for Update As Frame:
 * Replaces selected column groups with the result of the expression only where the filter is true.
 *
 * The [expression] is always evaluated once on the complete column group, so its result is expected to
 * contain one row per row of the original group, in the same order. When a filter is present, each row
 * that matches it is taken from that result at the *same row index*; every other row is kept as it was.
 *
 * @param expression receives the selected column group (as both receiver and argument) and returns the
 *   frame whose rows replace the matching rows of that group.
 * @return the original frame when it is empty, otherwise a frame in which the selected column groups
 *   have been replaced as described above.
 */
internal fun <T, C, R> Update<T, DataRow<C>>.asFrameImpl(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
    if (df.isEmpty()) {
        df
    } else {
        df.replace(columns).with { column ->
            val src = column.asColumnGroup()
            // Evaluated on the whole group: `updatedColumn` is indexed by the original row index.
            val updatedColumn = expression(src, src).asColumnGroup(src.name())
            if (filter == null) {
                // If there is no filter, we simply replace the selected column groups with the result of the expression
                updatedColumn
            } else {
                // If there is a filter, we collect the indices of the rows that are inside the filter
                // NOTE(review): only matching values are appended to the collector, so `collector.values`
                // is not aligned with row indices for filters that inspect previously added values — confirm.
                val collector = createDataCollector<DataRow<C>>(column.size, column.type)
                val matchingIndices = buildSet<Int> {
                    df.indices().forEach { rowIndex ->
                        val row = AddDataRowImpl(rowIndex, df, collector.values)
                        val currentValue = row[src]

                        if (filter.invoke(row, currentValue)) {
                            add(rowIndex)
                            collector.add(currentValue)
                        }
                    }
                }

                // Then we only replace the original rows with the updated rows that are inside the filter.
                // The updated row is looked up by the row's own index, not by its position among the
                // matches; otherwise a matching row would receive the values computed for another row.
                src.rows().map { row ->
                    if (row.index in matchingIndices) updatedColumn[row.index] else row
                }.toColumn(src.name)
            }
        }
    }
83+
4384
internal fun <T, C> DataColumn<C>.updateImpl(
4485
df: DataFrame<T>,
4586
filter: RowValueFilter<T, C>?,
46-
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?
87+
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?,
4788
): DataColumn<C> {
4889
val collector = createDataCollector<C>(size, type)
4990
val src = this
@@ -75,6 +116,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
75116
val groups = (values as List<AnyFrame>)
76117
DataColumn.createFrameColumn(name, groups) as DataColumn<T>
77118
}
119+
78120
is ColumnGroup<*> -> {
79121
this.columns().mapIndexed { colIndex, col ->
80122
val newValues = values.map {
@@ -88,6 +130,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
88130
col.updateWith(newValues)
89131
}.toDataFrame().let { DataColumn.createColumnGroup(name, it) } as DataColumn<T>
90132
}
133+
91134
else -> {
92135
var nulls = false
93136
val kclass = type.jvmErasure

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.api
33
import io.kotest.matchers.shouldBe
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
6+
import org.jetbrains.kotlinx.dataframe.size
67
import org.junit.Test
78

89
class UpdateTests {
@@ -14,6 +15,47 @@ class UpdateTests {
1415
df.update { col }.with { 2 } shouldBe df
1516
}
1617

18+
/** Schema of the column group used by the `update asFrame` test: an [Int] column `a` and a [String] column `b`. */
@DataSchema
interface AAndB {
    val a: Int
    val b: String
}
23+
24+
/**
 * Row type for the `update asFrame` test: the [AAndB] columns plus a [Boolean] column `c`
 * that the test uses as its filter condition.
 */
@DataSchema
data class Update(
    override val a: Int,
    override val b: String,
    val c: Boolean,
) : AAndB
30+
31+
/**
 * Checks that `update { group }.where { ... }.asFrame { ... }` hands the full group to the expression
 * but only changes the rows that satisfy the `where` condition, and that the group keeps its name.
 */
@Test
fun `update asFrame`() {
    val df = listOf(
        Update(a = 1, b = "a", c = true),
        Update(a = 2, b = "b", c = false),
    ).toDataFrame()

    val group by columnGroup<AAndB>() named "Some Group"
    val groupedDf = df.group { a and b }.into { group }

    val res = groupedDf
        .update { group }
        .where { !c }
        .asFrame {
            // the expression still sees every row of the group, not only the filtered ones
            size.nrow shouldBe 2

            // the outcome of this update is kept only for rows where `.where { !c }` holds
            update { a }.with { 0 }
        }

    val resultGroup = res[{ group }]
    val aValues = resultGroup.map { it.a }.toList()

    // first row has c == true, so it is untouched; second row has c == false, so it is updated
    aValues[0] shouldBe 1
    aValues[1] shouldBe 0

    resultGroup.name() shouldBe "Some Group"
}
58+
1759
@DataSchema
1860
interface SchemaA {
1961
val i: Int?

0 commit comments

Comments
 (0)