Skip to content

Commit 2d08793

Browse files
committed
getAt
1 parent da89e93 commit 2d08793

File tree

14 files changed

+225
-100
lines changed

14 files changed

+225
-100
lines changed

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/ColumnsSelectionDsl.kt

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4031,11 +4031,7 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
40314031
public fun <C> ColumnSet<C>.cols(
40324032
firstIndex: Int,
40334033
vararg otherIndices: Int,
4034-
): ColumnSet<C> = headPlusArray(firstIndex, otherIndices).let { indices ->
4035-
transform { list ->
4036-
indices.map { list[it] }
4037-
}
4038-
}
4034+
): ColumnSet<C> = colsInternal(headPlusArray(firstIndex, otherIndices)) as ColumnSet<C>
40394035

40404036
public operator fun <C> ColumnSet<C>.get(
40414037
firstIndex: Int,
@@ -4045,9 +4041,7 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
40454041
public fun SingleColumn<*>.cols(
40464042
firstIndex: Int,
40474043
vararg otherIndices: Int,
4048-
): ColumnSet<*> = headPlusArray(firstIndex, otherIndices).let { indices ->
4049-
transform { it.flatMap { it.children().let { children -> indices.map { children[it] } } } }
4050-
}
4044+
): ColumnSet<*> = colsInternal(headPlusArray(firstIndex, otherIndices))
40514045

40524046
/**
40534047
*
@@ -4092,12 +4086,12 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
40924086
// region ranges
40934087

40944088
public fun <C> ColumnSet<C>.cols(range: IntRange): ColumnSet<C> =
4095-
transform { it.subList(range.first, range.last + 1) }
4089+
colsInternal(range) as ColumnSet<C>
40964090

40974091
public operator fun <C> ColumnSet<C>.get(range: IntRange): ColumnSet<C> = cols(range)
40984092

40994093
public fun SingleColumn<*>.cols(range: IntRange): ColumnSet<*> =
4100-
transform { it.flatMap { it.children().subList(range.first, range.last + 1) } }
4094+
colsInternal(range)
41014095

41024096
/**
41034097
*
@@ -4810,6 +4804,28 @@ internal fun ColumnSet<*>.colsInternal(predicate: ColumnFilter<*>): ColumnSet<*>
48104804
}.filter(predicate)
48114805
}
48124806

4807+
internal fun ColumnSet<*>.colsInternal(indices: IntArray): ColumnSet<*> =
4808+
transform {
4809+
if (isSingleColumnGroup(it)) {
4810+
it.single().children()
4811+
} else {
4812+
it
4813+
}.let { cols ->
4814+
indices.map { cols[it] }
4815+
}
4816+
}
4817+
4818+
internal fun ColumnSet<*>.colsInternal(range: IntRange): ColumnSet<*> =
4819+
transform {
4820+
if (isSingleColumnGroup(it)) {
4821+
it.single().children()
4822+
} else {
4823+
it
4824+
}.let { cols ->
4825+
cols.subList(range.first, range.last + 1)
4826+
}
4827+
}
4828+
48134829
internal fun ColumnSet<*>.allInternal(): ColumnSet<*> =
48144830
transform {
48154831
if (isSingleColumnGroup(it)) {

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/ColumnSet.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ public interface ColumnSet<out C> {
3333
public fun interface ColumnSetTransformer {
3434

3535
public fun transform(columnSet: ColumnSet<*>): ColumnSet<*>
36-
37-
public operator fun invoke(columnSet: ColumnSet<*>): ColumnSet<*> = transform(columnSet)
3836
}
3937

38+
public operator fun ColumnSetTransformer.invoke(columnSet: ColumnSet<*>): ColumnSet<*> = transform(columnSet)
39+
4040
public class ColumnResolutionContext internal constructor(
4141
internal val df: DataFrame<*>,
4242
internal val unresolvedColumnsPolicy: UnresolvedColumnsPolicy,

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/SingleColumn.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public interface SingleColumn<out C> : ColumnSet<C> {
1616
context: ColumnResolutionContext,
1717
): List<ColumnWithPath<C>> = resolveSingle(context)?.let { listOf(it) } ?: emptyList()
1818

19+
/** By default, we transform the current SingleColumn using the transformer and then resolve it */
1920
override fun resolveAfterTransform(
2021
context: ColumnResolutionContext,
2122
transformer: ColumnSetTransformer,
@@ -25,6 +26,7 @@ public interface SingleColumn<out C> : ColumnSet<C> {
2526
public fun resolveSingle(context: ColumnResolutionContext): ColumnWithPath<C>?
2627
}
2728

29+
2830
@OptIn(ExperimentalContracts::class)
2931
public fun ColumnSet<*>.isSingleColumn(): Boolean {
3032
contract {

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ColumnAccessorImpl.kt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@ import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
66
import org.jetbrains.kotlinx.dataframe.api.cast
77
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
88
import org.jetbrains.kotlinx.dataframe.api.toPath
9-
import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor
10-
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
11-
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
12-
import org.jetbrains.kotlinx.dataframe.columns.ColumnResolutionContext
13-
import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
9+
import org.jetbrains.kotlinx.dataframe.columns.*
1410

1511
internal class ColumnAccessorImpl<T>(val path: ColumnPath) : ColumnAccessor<T> {
1612

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/Utils.kt

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ internal fun <A, B> SingleColumn<A>.transformSingle(
8686
?.let(converter)
8787
?: emptyList()
8888

89+
// calling resolveAfterTransform up the receiver column set call chain
8990
override fun resolveAfterTransform(
9091
context: ColumnResolutionContext,
9192
transformer: ColumnSetTransformer,
@@ -94,25 +95,33 @@ internal fun <A, B> SingleColumn<A>.transformSingle(
9495
.flatMap(converter)
9596
}
9697

97-
internal fun <A, B> SingleColumn<A>.transformRemainingSingle(
98-
converter: (ColumnWithPath<A>) -> ColumnWithPath<B>,
99-
): SingleColumn<B> = object : SingleColumn<B> {
100-
override fun resolveSingle(context: ColumnResolutionContext): ColumnWithPath<B>? =
101-
this@transformRemainingSingle
102-
.resolveSingle(context)
103-
?.let(converter)
104-
}
98+
//internal fun <A, B> SingleColumn<A>.transformRemainingSingle(
99+
// converter: (ColumnWithPath<A>) -> ColumnWithPath<B>,
100+
//): SingleColumn<B> = object : SingleColumn<B> {
101+
// override fun resolveSingle(context: ColumnResolutionContext): ColumnWithPath<B>? =
102+
// this@transformRemainingSingle
103+
// .resolveSingle(context)
104+
// ?.let(converter)
105+
//
106+
// override fun resolveAfterTransform(
107+
// context: ColumnResolutionContext,
108+
// transformer: ColumnSetTransformer
109+
// ): List<ColumnWithPath<B>> =
110+
// [email protected](context, transformer)
111+
//
112+
//}
105113

106114
internal fun <A> ColumnSet<A>.wrap(): ColumnSet<A> = object : ColumnSet<A> {
107115

108116
override fun resolve(context: ColumnResolutionContext): List<ColumnWithPath<A>> =
109117
this@wrap.resolve(context)
110118

119+
// applying transformer here
111120
override fun resolveAfterTransform(
112121
context: ColumnResolutionContext,
113122
transformer: ColumnSetTransformer,
114123
): List<ColumnWithPath<A>> =
115-
transformer.transform(this@wrap).cast<A>()
124+
transformer(this@wrap).cast<A>()
116125
.resolve(context)
117126
}
118127

@@ -123,39 +132,37 @@ internal fun <C> ColumnSet<C>.recursivelyImpl(
123132

124133
val flattenTransformer = object : ColumnSetTransformer {
125134

126-
private fun flattenColumnWithPaths(list: List<ColumnWithPath<*>>, columnSet: ColumnSet<*>): List<ColumnWithPath<*>> {
135+
override fun transform(columnSet: ColumnSet<*>): ColumnSet<*> = columnSet.transform { list ->
127136
val cols =
128137
if (columnSet.isSingleColumnGroup(list)) {
129138
list.single().children()
130139
} else {
131140
list
132141
}
133142

134-
return if (includeTopLevel) {
143+
if (includeTopLevel) {
135144
cols.flattenRecursively()
136145
} else {
137146
cols
138147
.filter { it.isColumnGroup() }
139148
.flatMap { it.children().flattenRecursively() }
140149
}.filter { includeGroups || !it.isColumnGroup() }
141150
}
142-
143-
override fun transform(columnSet: ColumnSet<*>): ColumnSet<*> = columnSet.transform {
144-
flattenColumnWithPaths(it, columnSet)
145-
}
146151
}
147152

153+
// calling resolveAfterTransform up the receiver column set call chain
148154
override fun resolve(
149155
context: ColumnResolutionContext,
150156
): List<ColumnWithPath<C>> =
151157
this@recursivelyImpl
152158
.resolveAfterTransform(context = context, transformer = flattenTransformer)
153159

160+
// applying transformer here
154161
override fun resolveAfterTransform(
155162
context: ColumnResolutionContext,
156163
transformer: ColumnSetTransformer,
157164
): List<ColumnWithPath<C>> =
158-
transformer.transform(this@recursivelyImpl).cast<C>()
165+
transformer(this@recursivelyImpl).cast<C>()
159166
.resolveAfterTransform(context = context, transformer = flattenTransformer)
160167
}
161168

@@ -167,11 +174,12 @@ internal fun <A, B> ColumnSet<A>.transform(
167174
.resolve(context)
168175
.let(converter)
169176

177+
// applying transformer here
170178
override fun resolveAfterTransform(
171179
context: ColumnResolutionContext,
172180
transformer: ColumnSetTransformer,
173181
): List<ColumnWithPath<B>> =
174-
transformer.transform(this@transform).cast<A>()
182+
transformer(this@transform).cast<A>()
175183
.resolve(context)
176184
.let { converter(it) }
177185
}
@@ -184,11 +192,12 @@ internal fun <A, B> ColumnSet<A>.transformWithContext(
184192
.resolve(context)
185193
.let { converter(context, it) }
186194

195+
// applying transformer here
187196
override fun resolveAfterTransform(
188197
context: ColumnResolutionContext,
189198
transformer: ColumnSetTransformer,
190199
): List<ColumnWithPath<B>> =
191-
transformer.transform(this@transformWithContext).cast<A>()
200+
transformer(this@transformWithContext).cast<A>()
192201
.resolve(context)
193202
.let { converter(context, it) }
194203
}
@@ -197,6 +206,7 @@ internal fun <T> ColumnSet<T>.singleImpl() = object : SingleColumn<T> {
197206
override fun resolveSingle(context: ColumnResolutionContext): ColumnWithPath<T>? =
198207
this@singleImpl.resolve(context).singleOrNull()
199208

209+
// passing back the transformer to the previous call
200210
override fun resolveAfterTransform(
201211
context: ColumnResolutionContext,
202212
transformer: ColumnSetTransformer,
@@ -207,11 +217,21 @@ internal fun <T> ColumnSet<T>.singleImpl() = object : SingleColumn<T> {
207217
)
208218
}
209219

210-
internal fun <T> ColumnSet<T>.getAt(index: Int) = object : SingleColumn<T> {
220+
internal fun <T> ColumnSet<T>.getAt(index: Int): SingleColumn<T> = object : SingleColumn<T> {
211221
override fun resolveSingle(context: ColumnResolutionContext): ColumnWithPath<T>? =
212222
this@getAt
213223
.resolve(context)
214224
.getOrNull(index)
225+
226+
// passing back the transformer to the previous call
227+
override fun resolveAfterTransform(
228+
context: ColumnResolutionContext,
229+
transformer: ColumnSetTransformer,
230+
): List<ColumnWithPath<T>> =
231+
this@getAt.resolveAfterTransform(
232+
context = context,
233+
transformer = transformer,
234+
)
215235
}
216236

217237
internal fun <T> ColumnSet<T>.getChildrenAt(index: Int): ColumnSet<Any?> =

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/missing/MissingColumnGroup.kt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@ import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
88
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
99
import org.jetbrains.kotlinx.dataframe.api.cast
1010
import org.jetbrains.kotlinx.dataframe.api.name
11-
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
12-
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
13-
import org.jetbrains.kotlinx.dataframe.columns.ColumnResolutionContext
14-
import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
11+
import org.jetbrains.kotlinx.dataframe.columns.*
1512
import org.jetbrains.kotlinx.dataframe.columns.UnresolvedColumnsPolicy
1613
import org.jetbrains.kotlinx.dataframe.impl.columns.DataColumnGroup
1714
import org.jetbrains.kotlinx.dataframe.impl.columns.addPath

core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/recursively.kt

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

33
import io.kotest.matchers.shouldBe
4+
import io.kotest.matchers.shouldNotBe
45
import org.jetbrains.kotlinx.dataframe.alsoDebug
56
import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
67
import org.jetbrains.kotlinx.dataframe.impl.columns.recursivelyImpl
78
import org.jetbrains.kotlinx.dataframe.impl.columns.singleImpl
89
import org.jetbrains.kotlinx.dataframe.impl.columns.transform
910
import org.jetbrains.kotlinx.dataframe.samples.api.TestBase
11+
import org.jetbrains.kotlinx.dataframe.samples.api.city
12+
import org.jetbrains.kotlinx.dataframe.samples.api.firstName
1013
import org.jetbrains.kotlinx.dataframe.samples.api.name
1114
import org.junit.Test
1215

@@ -24,6 +27,10 @@ class Recursively : TestBase() {
2427
this.map { it.name to it.path } shouldBe other.map { it.name to it.path }
2528
}
2629

30+
infix fun List<ColumnWithPath<*>>.shouldNotBe(other: List<ColumnWithPath<*>>) {
31+
this.map { it.name to it.path } shouldNotBe other.map { it.name to it.path }
32+
}
33+
2734
private val recursivelyGoal = dfGroup.getColumnsWithPaths { dfs { true } }
2835
.sortedBy { it.name }
2936

@@ -34,19 +41,45 @@ class Recursively : TestBase() {
3441
.sortedBy { it.name }
3542

3643
@Test
37-
fun first() {
38-
dfGroup.select {
39-
first { it.data.any { it == "Alice" } }
40-
.recursively()
41-
}.alsoDebug()
42-
43-
dfGroup.select {
44-
first { it.data.any { it == "London" } }.recursively()
45-
}.alsoDebug()
44+
fun `first, last, and single`() {
45+
listOf(
46+
dfGroup.select { name.firstName.firstName },
47+
48+
dfGroup.select { first { it.data.any { it == "Alice" } }.recursively() },
49+
dfGroup.select { last { it.data.any { it == "Alice" } }.recursively() },
50+
dfGroup.select { single { it.data.any { it == "Alice" } }.recursively() },
51+
).shouldAllBeEqual()
52+
53+
listOf(
54+
dfGroup.select { city },
55+
56+
dfGroup.select { first { it.data.any { it == "London" } }.recursively() },
57+
dfGroup.select { last { it.data.any { it == "London" } }.recursively() },
58+
dfGroup.select { single { it.data.any { it == "London" } }.recursively() },
59+
).shouldAllBeEqual()
60+
}
61+
62+
@Test
63+
fun `get at`() {
64+
dfGroup.getColumnsWithPaths { it[0].recursively() }.print()
65+
66+
// dfGroup.getColumnsWithPaths { recursively()[0] }.print()
67+
}
68+
69+
@Test
70+
fun `combination`() {
71+
dfGroup.getColumnsWithPaths {
72+
cols { it.name in listOf("name", "firstName") }
73+
.last().recursively()
74+
} shouldNotBe
75+
dfGroup.getColumnsWithPaths {
76+
cols { it.name in listOf("name", "firstName") }.recursively()
77+
.last().recursively()
78+
}
4679
}
4780

4881
@Test
49-
fun recursively() {
82+
fun `recursively`() {
5083
dfGroup.getColumnsWithPaths { recursively() }.sortedBy { it.name } shouldBe recursivelyGoal
5184
dfGroup.getColumnsWithPaths { rec(includeGroups = false) }.sortedBy { it.name } shouldBe recursivelyNoGroups
5285
}

0 commit comments

Comments
 (0)