Skip to content

Commit 2fb55a0

Browse files
committed
[SPARK-51879] Support groupBy/rollup/cube in DataFrame
### What changes were proposed in this pull request? This PR aims to support `groupBy`, `rollup`, and `cube` API in `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No, these are additional APIs. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #87 from dongjoon-hyun/SPARK-51879. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 9d68da5 commit 2fb55a0

File tree

5 files changed

+156
-0
lines changed

5 files changed

+156
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,29 @@ public actor DataFrame: Sendable {
733733
return buildRepartition(numPartitions: numPartitions, shuffle: false)
734734
}
735735

736+
/// Groups the ``DataFrame`` using the specified columns, so we can run aggregation on them.
737+
/// - Parameter cols: Grouping column names.
738+
/// - Returns: A ``GroupedData``.
739+
public func groupBy(_ cols: String...) -> GroupedData {
740+
return GroupedData(self, GroupType.groupby, cols)
741+
}
742+
743+
/// Create a multi-dimensional rollup for the current ``DataFrame`` using the specified columns, so we
744+
/// can run aggregation on them.
745+
/// - Parameter cols: Grouping column names.
746+
/// - Returns: A ``GroupedData``.
747+
public func rollup(_ cols: String...) -> GroupedData {
748+
return GroupedData(self, GroupType.rollup, cols)
749+
}
750+
751+
/// Create a multi-dimensional cube for the current ``DataFrame`` using the specified columns, so we
752+
/// can run aggregation on them.
753+
/// - Parameter cols: Grouping column names.
754+
/// - Returns: A ``GroupedData``.
755+
public func cube(_ cols: String...) -> GroupedData {
756+
return GroupedData(self, GroupType.cube, cols)
757+
}
758+
736759
/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
737760
public var write: DataFrameWriter {
738761
get {

Sources/SparkConnect/Extension.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ extension String {
9393
default: JoinType.inner
9494
}
9595
}
96+
97+
var toGroupType: GroupType {
98+
return switch self.lowercased() {
99+
case "groupby": .groupby
100+
case "rollup": .rollup
101+
case "cube": .cube
102+
case "pivot": .pivot
103+
case "groupingsets": .groupingSets
104+
default: .UNRECOGNIZED(-1)
105+
}
106+
}
96107
}
97108

98109
extension [String: String] {
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
public actor GroupedData {
21+
let df: DataFrame
22+
let groupType: GroupType
23+
let groupingCols: [String]
24+
25+
init(_ df: DataFrame, _ groupType: GroupType, _ groupingCols: [String]) {
26+
self.df = df
27+
self.groupType = groupType
28+
self.groupingCols = groupingCols
29+
}
30+
31+
public func agg(_ exprs: String...) async -> DataFrame {
32+
var aggregate = Aggregate()
33+
aggregate.input = await (self.df.getPlan() as! Plan).root
34+
aggregate.groupType = self.groupType
35+
aggregate.groupingExpressions = self.groupingCols.map {
36+
var expr = Spark_Connect_Expression()
37+
expr.expressionString = $0.toExpressionString
38+
return expr
39+
}
40+
aggregate.aggregateExpressions = exprs.map {
41+
var expr = Spark_Connect_Expression()
42+
expr.expressionString = $0.toExpressionString
43+
return expr
44+
}
45+
var relation = Relation()
46+
relation.aggregate = aggregate
47+
var plan = Plan()
48+
plan.opType = .root(relation)
49+
return await DataFrame(spark: df.spark, plan: plan)
50+
}
51+
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// specific language governing permissions and limitations
1717
// under the License.
1818

19+
typealias Aggregate = Spark_Connect_Aggregate
1920
typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
2021
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
2122
typealias Command = Spark_Connect_Command
@@ -29,6 +30,7 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
2930
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
3031
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
3132
typealias Filter = Spark_Connect_Filter
33+
typealias GroupType = Spark_Connect_Aggregate.GroupType
3234
typealias Join = Spark_Connect_Join
3335
typealias JoinType = Spark_Connect_Join.JoinType
3436
typealias KeyValue = Spark_Connect_KeyValue

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ import SparkConnect
2424

2525
/// A test suite for `DataFrame`
2626
struct DataFrameTests {
27+
let DEALER_TABLE =
28+
"""
29+
VALUES
30+
(100, 'Fremont', 'Honda Civic', 10),
31+
(100, 'Fremont', 'Honda Accord', 15),
32+
(100, 'Fremont', 'Honda CRV', 7),
33+
(200, 'Dublin', 'Honda Civic', 20),
34+
(200, 'Dublin', 'Honda Accord', 10),
35+
(200, 'Dublin', 'Honda CRV', 3),
36+
(300, 'San Jose', 'Honda Civic', 5),
37+
(300, 'San Jose', 'Honda Accord', 8)
38+
dealer (id, city, car_model, quantity)
39+
"""
40+
2741
@Test
2842
func sparkSession() async throws {
2943
let spark = try await SparkSession.builder.getOrCreate()
@@ -577,6 +591,61 @@ struct DataFrameTests {
577591
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
578592
await spark.stop()
579593
}
594+
595+
@Test
596+
func groupBy() async throws {
597+
let spark = try await SparkSession.builder.getOrCreate()
598+
let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect()
599+
#expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), Row("2", "1", "2", "2.0")])
600+
await spark.stop()
601+
}
602+
603+
@Test
604+
func rollup() async throws {
605+
let spark = try await SparkSession.builder.getOrCreate()
606+
let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
607+
.agg("sum(quantity) sum").orderBy("city", "car_model").collect()
608+
#expect(rows == [
609+
Row("Dublin", "Honda Accord", "10"),
610+
Row("Dublin", "Honda CRV", "3"),
611+
Row("Dublin", "Honda Civic", "20"),
612+
Row("Dublin", nil, "33"),
613+
Row("Fremont", "Honda Accord", "15"),
614+
Row("Fremont", "Honda CRV", "7"),
615+
Row("Fremont", "Honda Civic", "10"),
616+
Row("Fremont", nil, "32"),
617+
Row("San Jose", "Honda Accord", "8"),
618+
Row("San Jose", "Honda Civic", "5"),
619+
Row("San Jose", nil, "13"),
620+
Row(nil, nil, "78"),
621+
])
622+
await spark.stop()
623+
}
624+
625+
@Test
626+
func cube() async throws {
627+
let spark = try await SparkSession.builder.getOrCreate()
628+
let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
629+
.agg("sum(quantity) sum").orderBy("city", "car_model").collect()
630+
#expect(rows == [
631+
Row("Dublin", "Honda Accord", "10"),
632+
Row("Dublin", "Honda CRV", "3"),
633+
Row("Dublin", "Honda Civic", "20"),
634+
Row("Dublin", nil, "33"),
635+
Row("Fremont", "Honda Accord", "15"),
636+
Row("Fremont", "Honda CRV", "7"),
637+
Row("Fremont", "Honda Civic", "10"),
638+
Row("Fremont", nil, "32"),
639+
Row("San Jose", "Honda Accord", "8"),
640+
Row("San Jose", "Honda Civic", "5"),
641+
Row("San Jose", nil, "13"),
642+
Row(nil, "Honda Accord", "33"),
643+
Row(nil, "Honda CRV", "10"),
644+
Row(nil, "Honda Civic", "35"),
645+
Row(nil, nil, "78"),
646+
])
647+
await spark.stop()
648+
}
580649
#endif
581650

582651
@Test

0 commit comments

Comments
 (0)