Skip to content

Commit a089dcb

Browse files
authored
Merge pull request #283 from Kotlin/asFrame-fix
[Fix] Update.asFrame now takes filter into account.
2 parents c353922 + 1742f43 commit a089dcb

File tree

3 files changed

+91
-2
lines changed
  • core/src

3 files changed

+91
-2
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
1111
import org.jetbrains.kotlinx.dataframe.RowValueFilter
1212
import org.jetbrains.kotlinx.dataframe.Selector
1313
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
14+
import org.jetbrains.kotlinx.dataframe.impl.api.asFrameImpl
1415
import org.jetbrains.kotlinx.dataframe.impl.api.updateImpl
1516
import org.jetbrains.kotlinx.dataframe.impl.api.updateWithValuePerColumnImpl
1617
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
@@ -57,7 +58,7 @@ public infix fun <T, C> Update<T, C>.with(expression: UpdateExpression<T, C, C?>
5758
}
5859

5960
/**
 * Replaces every column group selected by this [Update] with the frame that [expression] produces for it.
 *
 * Diff note: the `-` line below is the previous body, which never consulted the update's filter;
 * the `+` line is its replacement, which hands the work to [asFrameImpl].
 *
 * @param expression computes the replacement frame from the column group viewed as a [DataFrame].
 * @return the data frame with the selected column groups replaced.
 */
public infix fun <T, C, R> Update<T, DataRow<C>>.asFrame(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
60-
df.replace(columns).with { it.asColumnGroup().let { expression(it, it) }.asColumnGroup(it.name()) }
61+
asFrameImpl(expression)
6162

6263
public fun <T, C> Update<T, C>.asNullable(): Update<T, C?> = this as Update<T, C?>
6364

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/update.kt

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,28 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
44
import org.jetbrains.kotlinx.dataframe.AnyRow
55
import org.jetbrains.kotlinx.dataframe.DataColumn
66
import org.jetbrains.kotlinx.dataframe.DataFrame
7+
import org.jetbrains.kotlinx.dataframe.DataFrameExpression
8+
import org.jetbrains.kotlinx.dataframe.DataRow
79
import org.jetbrains.kotlinx.dataframe.RowValueFilter
810
import org.jetbrains.kotlinx.dataframe.Selector
911
import org.jetbrains.kotlinx.dataframe.api.AddDataRow
1012
import org.jetbrains.kotlinx.dataframe.api.Update
13+
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
14+
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
1115
import org.jetbrains.kotlinx.dataframe.api.cast
1216
import org.jetbrains.kotlinx.dataframe.api.indices
1317
import org.jetbrains.kotlinx.dataframe.api.isEmpty
1418
import org.jetbrains.kotlinx.dataframe.api.name
1519
import org.jetbrains.kotlinx.dataframe.api.replace
20+
import org.jetbrains.kotlinx.dataframe.api.toColumn
1621
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
1722
import org.jetbrains.kotlinx.dataframe.api.with
1823
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
1924
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
2025
import org.jetbrains.kotlinx.dataframe.columns.size
2126
import org.jetbrains.kotlinx.dataframe.impl.columns.AddDataRowImpl
2227
import org.jetbrains.kotlinx.dataframe.impl.createDataCollector
28+
import org.jetbrains.kotlinx.dataframe.index
2329
import org.jetbrains.kotlinx.dataframe.type
2430
import kotlin.reflect.full.isSubclassOf
2531
import kotlin.reflect.full.withNullability
@@ -40,10 +46,48 @@ internal fun <T, C> Update<T, C>.updateWithValuePerColumnImpl(selector: Selector
4046
}
4147
}
4248

49+
/**
 * Backs `Update.asFrame`.
 *
 * Each selected column group is passed, as a data frame, to [expression]; the returned frame becomes the
 * new content of that group under the group's original name. The expression always receives the complete
 * group, regardless of the filter. When the update has a filter, only the rows accepted by it take their
 * values from the expression result; every other row keeps its source values.
 *
 * An empty data frame is returned as is, without evaluating [expression].
 *
 * @param expression computes the replacement frame for a column group.
 * @return the data frame with the selected column groups replaced.
 */
internal fun <T, C, R> Update<T, DataRow<C>>.asFrameImpl(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> {
    if (df.isEmpty()) return df

    return df.replace(columns).with { column ->
        // Evaluate the expression once, on the whole source group.
        val srcColumnGroup = column.asColumnGroup()
        val srcFrame = srcColumnGroup.asDataFrame()
        val updatedColumnGroup = expression(srcFrame, srcFrame).asColumnGroup(srcColumnGroup.name())

        if (filter == null) {
            // No filter: the expression result replaces the group wholesale.
            updatedColumnGroup
        } else {
            // With a filter: pick the updated row only where the filter accepts the matching source row.
            srcColumnGroup.replaceRowsIf(from = updatedColumnGroup) { groupRow ->
                val srcRow = df[groupRow.index]
                val srcValue = srcRow[srcColumnGroup]

                filter.invoke(srcRow, srcValue)
            }
        }
    }
}
77+
78+
/**
 * Builds a column group with this group's name in which each row satisfying [condition] is taken
 * from [from] at the same index, while all remaining rows are kept from the receiver.
 *
 * @param from the group supplying replacement rows.
 * @param condition decides, per receiver row, whether it is replaced; by default every row is.
 */
private fun <C, R> ColumnGroup<C>.replaceRowsIf(
    from: ColumnGroup<R>,
    condition: (DataRow<C>) -> Boolean = { true },
): ColumnGroup<C> {
    val mergedRows = values().map { row ->
        if (condition(row)) from[row.index] else row
    }
    return mergedRows
        .toColumn(name)
        .asColumnGroup()
        .cast()
}
86+
4387
internal fun <T, C> DataColumn<C>.updateImpl(
4488
df: DataFrame<T>,
4589
filter: RowValueFilter<T, C>?,
46-
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?
90+
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?,
4791
): DataColumn<C> {
4892
val collector = createDataCollector<C>(size, type)
4993
val src = this
@@ -75,6 +119,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
75119
val groups = (values as List<AnyFrame>)
76120
DataColumn.createFrameColumn(name, groups) as DataColumn<T>
77121
}
122+
78123
is ColumnGroup<*> -> {
79124
this.columns().mapIndexed { colIndex, col ->
80125
val newValues = values.map {
@@ -88,6 +133,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
88133
col.updateWith(newValues)
89134
}.toDataFrame().let { DataColumn.createColumnGroup(name, it) } as DataColumn<T>
90135
}
136+
91137
else -> {
92138
var nulls = false
93139
val kclass = type.jvmErasure

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.api
33
import io.kotest.matchers.shouldBe
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
6+
import org.jetbrains.kotlinx.dataframe.size
67
import org.junit.Test
78

89
class UpdateTests {
@@ -14,6 +15,47 @@ class UpdateTests {
1415
df.update { col }.with { 2 } shouldBe df
1516
}
1617

18+
/** Test schema declaring the `a` and `b` columns; it types the column group built in the `update asFrame` test. */
@DataSchema
19+
interface DataPart {
20+
val a: Int
21+
val b: String
22+
}
23+
24+
/** Test row type: the [DataPart] columns `a` and `b` plus the Boolean `c` that the `update asFrame` test filters on. */
@DataSchema
25+
data class Data(
26+
override val a: Int,
27+
override val b: String,
28+
val c: Boolean,
29+
) : DataPart
30+
31+
/**
 * `update { group }.where { ... }.asFrame { ... }` must hand the whole group to the expression,
 * apply the result only to rows matching the filter, and keep the group's name.
 */
@Test
fun `update asFrame`() {
    val source = listOf(
        Data(1, "a", true),
        Data(2, "b", false),
    ).toDataFrame()

    val group by columnGroup<DataPart>() named "Some Group"
    val groupedDf = source.group { a and b }.into { group }

    val res = groupedDf
        .update { group }
        .where { !c }
        .asFrame {
            // the expression sees every row of the group, not just the filtered ones
            size.nrow shouldBe 2

            // only rows accepted by `.where { !c }` end up with this value
            update { a }.with { 0 }
        }

    val aValues = res[{ group }].map { it.a }.toList()
    aValues[0] shouldBe 1
    aValues[1] shouldBe 0

    res[{ group }].name() shouldBe "Some Group"
}
58+
1759
@DataSchema
1860
interface SchemaA {
1961
val i: Int?

0 commit comments

Comments
 (0)