diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 83dbba1..8679362 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -733,6 +733,29 @@ public actor DataFrame: Sendable {
     return buildRepartition(numPartitions: numPartitions, shuffle: false)
   }
 
+  /// Groups the ``DataFrame`` using the specified columns, so we can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func groupBy(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.groupby, cols)
+  }
+
+  /// Create a multi-dimensional rollup for the current ``DataFrame`` using the specified columns, so we
+  /// can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func rollup(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.rollup, cols)
+  }
+
+  /// Create a multi-dimensional cube for the current ``DataFrame`` using the specified columns, so we
+  /// can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func cube(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.cube, cols)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index d41b5b1..5d75b3d 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -93,6 +93,17 @@ extension String {
     default: JoinType.inner
     }
   }
+
+  var toGroupType: GroupType {
+    return switch self.lowercased() {
+    case "groupby": .groupby
+    case "rollup": .rollup
+    case "cube": .cube
+    case "pivot": .pivot
+    case "groupingsets": .groupingSets
+    default: .UNRECOGNIZED(-1)
+    }
+  }
 }
 
 extension [String: String] {
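The `String.toGroupType` helper mirrors the existing `toJoinType` property directly above it in Extension.swift: matching is case-insensitive, and any unknown name falls through to `.UNRECOGNIZED(-1)`. Nothing in this change calls it yet. A minimal in-module sketch of the mapping (both the helper and the `GroupType` alias are internal, so the strings below are illustrative and this only compiles inside the SparkConnect module):

    // Case-insensitive mapping of user-facing names to the Spark Connect proto enum.
    assert("GroupBy".toGroupType == GroupType.groupby)
    assert("ROLLUP".toGroupType == GroupType.rollup)
    assert("GroupingSets".toGroupType == GroupType.groupingSets)
    assert("bogus".toGroupType == GroupType.UNRECOGNIZED(-1))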
diff --git a/Sources/SparkConnect/GroupedData.swift b/Sources/SparkConnect/GroupedData.swift
new file mode 100644
index 0000000..a460832
--- /dev/null
+++ b/Sources/SparkConnect/GroupedData.swift
@@ -0,0 +1,51 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+public actor GroupedData {
+  let df: DataFrame
+  let groupType: GroupType
+  let groupingCols: [String]
+
+  init(_ df: DataFrame, _ groupType: GroupType, _ groupingCols: [String]) {
+    self.df = df
+    self.groupType = groupType
+    self.groupingCols = groupingCols
+  }
+
+  public func agg(_ exprs: String...) async -> DataFrame {
+    var aggregate = Aggregate()
+    aggregate.input = await (self.df.getPlan() as! Plan).root
+    aggregate.groupType = self.groupType
+    aggregate.groupingExpressions = self.groupingCols.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    aggregate.aggregateExpressions = exprs.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    var relation = Relation()
+    relation.aggregate = aggregate
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return await DataFrame(spark: df.spark, plan: plan)
+  }
+}
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 2858de2..f0dcc04 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -16,6 +16,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+typealias Aggregate = Spark_Connect_Aggregate
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
 typealias Command = Spark_Connect_Command
@@ -29,6 +30,7 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
+typealias GroupType = Spark_Connect_Aggregate.GroupType
 typealias Join = Spark_Connect_Join
 typealias JoinType = Spark_Connect_Join.JoinType
 typealias KeyValue = Spark_Connect_KeyValue
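Taken together, `GroupedData` and the new aliases complete the plan-building path: grouping columns and aggregate expressions both travel to the server as unresolved expression strings inside a `Spark_Connect_Aggregate` relation, and the tests below exercise this end-to-end. For reference, a minimal call-site sketch (assuming a reachable Spark Connect server; `SparkSession.builder`, `sql`, and `collect()` are the package's existing APIs, and the inline `VALUES` table is a made-up example):

    let spark = try await SparkSession.builder.getOrCreate()
    let df = try await spark.sql(
      "VALUES (100, 'Fremont', 10), (200, 'Dublin', 20) dealer (id, city, quantity)")
    // groupBy, rollup, and cube all return a GroupedData; a single `await`
    // covers both the hop into the DataFrame actor and the async agg call.
    let rows = try await df.groupBy("city").agg("sum(quantity)").collect()
    await spark.stop()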
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 5772120..7fd1403 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -24,6 +24,20 @@ import SparkConnect
 
 /// A test suite for `DataFrame`
 struct DataFrameTests {
+  let DEALER_TABLE =
+    """
+    VALUES
+    (100, 'Fremont', 'Honda Civic', 10),
+    (100, 'Fremont', 'Honda Accord', 15),
+    (100, 'Fremont', 'Honda CRV', 7),
+    (200, 'Dublin', 'Honda Civic', 20),
+    (200, 'Dublin', 'Honda Accord', 10),
+    (200, 'Dublin', 'Honda CRV', 3),
+    (300, 'San Jose', 'Honda Civic', 5),
+    (300, 'San Jose', 'Honda Accord', 8)
+    dealer (id, city, car_model, quantity)
+    """
+
   @Test
   func sparkSession() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
@@ -577,6 +591,61 @@
     #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
     await spark.stop()
   }
+
+  @Test
+  func groupBy() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect()
+    #expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), Row("2", "1", "2", "2.0")])
+    await spark.stop()
+  }
+
+  @Test
+  func rollup() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
+      .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+    #expect(rows == [
+      Row("Dublin", "Honda Accord", "10"),
+      Row("Dublin", "Honda CRV", "3"),
+      Row("Dublin", "Honda Civic", "20"),
+      Row("Dublin", nil, "33"),
+      Row("Fremont", "Honda Accord", "15"),
+      Row("Fremont", "Honda CRV", "7"),
+      Row("Fremont", "Honda Civic", "10"),
+      Row("Fremont", nil, "32"),
+      Row("San Jose", "Honda Accord", "8"),
+      Row("San Jose", "Honda Civic", "5"),
+      Row("San Jose", nil, "13"),
+      Row(nil, nil, "78"),
+    ])
+    await spark.stop()
+  }
+
+  @Test
+  func cube() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
+      .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+    #expect(rows == [
+      Row("Dublin", "Honda Accord", "10"),
+      Row("Dublin", "Honda CRV", "3"),
+      Row("Dublin", "Honda Civic", "20"),
+      Row("Dublin", nil, "33"),
+      Row("Fremont", "Honda Accord", "15"),
+      Row("Fremont", "Honda CRV", "7"),
+      Row("Fremont", "Honda Civic", "10"),
+      Row("Fremont", nil, "32"),
+      Row("San Jose", "Honda Accord", "8"),
+      Row("San Jose", "Honda Civic", "5"),
+      Row("San Jose", nil, "13"),
+      Row(nil, "Honda Accord", "33"),
+      Row(nil, "Honda CRV", "10"),
+      Row(nil, "Honda Civic", "35"),
+      Row(nil, nil, "78"),
+    ])
+    await spark.stop()
+  }
 #endif
 
   @Test