diff --git a/core/api/core.api b/core/api/core.api index 89800bfd2d..4e6cf5ee8b 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -2128,6 +2128,7 @@ public abstract interface class org/jetbrains/kotlinx/dataframe/api/ExprColumnsS public abstract interface class org/jetbrains/kotlinx/dataframe/api/FilterColumnsSelectionDsl { public fun filter (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet; + public fun filterColumnGroups (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet; } public abstract interface class org/jetbrains/kotlinx/dataframe/api/FilterColumnsSelectionDsl$Grammar { @@ -5058,6 +5059,16 @@ public abstract interface class org/jetbrains/kotlinx/dataframe/columns/ColumnGr public abstract fun rename (Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup; } +public final class org/jetbrains/kotlinx/dataframe/columns/ColumnGroupWithPath { + public fun (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;Lorg/jetbrains/kotlinx/dataframe/columns/ColumnPath;)V + public final fun depth ()I + public final fun getData ()Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup; + public final fun getDepth ()I + public final fun getName ()Ljava/lang/String; + public final fun getParentName ()Ljava/lang/String; + public final fun getPath ()Lorg/jetbrains/kotlinx/dataframe/columns/ColumnPath; +} + public class org/jetbrains/kotlinx/dataframe/columns/ColumnKind : java/lang/Enum { public static final field Frame Lorg/jetbrains/kotlinx/dataframe/columns/ColumnKind; public static final field Group Lorg/jetbrains/kotlinx/dataframe/columns/ColumnKind; diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt index 667d71243d..702b31a5b6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt @@ -1,5 +1,6 @@ package org.jetbrains.kotlinx.dataframe.api +import org.jetbrains.kotlinx.dataframe.AnyRow import org.jetbrains.kotlinx.dataframe.ColumnFilter import org.jetbrains.kotlinx.dataframe.ColumnSelector import org.jetbrains.kotlinx.dataframe.ColumnsSelector @@ -8,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.Predicate import org.jetbrains.kotlinx.dataframe.RowFilter import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroupWithPath import org.jetbrains.kotlinx.dataframe.columns.ColumnPath import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.ColumnSet @@ -159,6 +161,25 @@ public interface FilterColumnsSelectionDsl { @Suppress("UNCHECKED_CAST") public fun ColumnSet.filter(predicate: ColumnFilter): ColumnSet = colsInternal(predicate as ColumnFilter<*>).cast() + + /** + * ## Filter [ColumnSet] + * + * #### For example: + * ```kotlin + * df.convert { + * colsAtAnyDepth().colGroups() + * .filter { it.data.containsColumn("myCol") + * }.with { ... } + * ``` + */ + @Suppress("INAPPLICABLE_JVM_NAME") + @JvmName("filterColumnGroups") + public fun ColumnSet.filter(predicate: Predicate>): ColumnSet<*> = + colsInternal { columnWithPath -> + columnWithPath.isColumnGroup() && + predicate(ColumnGroupWithPath(columnWithPath, columnWithPath.path)) + } } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/ColumnWithPath.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/ColumnWithPath.kt index d2e5f8a2ae..4a8d670f67 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/ColumnWithPath.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/columns/ColumnWithPath.kt @@ -63,3 +63,13 @@ public interface ColumnWithPath : DataColumn { public val ColumnWithPath.depth: Int get() = path.depth() public fun ColumnWithPath(column: DataColumn<*>, path: ColumnPath): ColumnWithPath<*> = column.addPath(path) + +public class ColumnGroupWithPath(public val data: ColumnGroup, public val path: ColumnPath) { + public val name: String get() = data.name() + + public val parentName: String? get() = path.parentName + + public fun depth(): Int = path.depth() + + public val depth: Int get() = path.depth() +} diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt index 009f44a133..2a378a80e5 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/filter.kt @@ -19,4 +19,12 @@ class FilterTests : ColumnsSelectionDslTests() { df.select { all().filter { true } } shouldBe df.select { all() } df.select { all().filter { false } } shouldBe df.select { none() } } + + @Test + fun `filter column group`() { + listOf( + df.select { name }, + df.select { colsAtAnyDepth().colGroups().filter { it.data.containsColumn("firstName") } }, + ) + } }