Skip to content

Commit ee7fca3

Browse files
committed
[SPARK-51504] Support DataFrame.select
1 parent af298fc commit ee7fca3

File tree

5 files changed

+69
-1
lines changed

5 files changed

+69
-1
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public actor DataFrame: Sendable {
3636
/// - Parameters:
3737
/// - spark: A ``SparkSession`` instance to use.
3838
/// - plan: A plan to execute.
39-
init(spark: SparkSession, plan: Plan) async throws {
39+
init(spark: SparkSession, plan: Plan) {
4040
self.spark = spark
4141
self.plan = plan
4242
}
@@ -192,4 +192,9 @@ public actor DataFrame: Sendable {
192192
print(table.render())
193193
}
194194
}
195+
196+
/// Projects the given columns and returns the result as a new ``DataFrame``.
/// - Parameter cols: Names of the columns to select; the list may be empty.
/// - Returns: A ``DataFrame`` sharing this frame's session, whose plan wraps
///   this frame's plan root in a projection.
public func select(_ cols: String...) -> DataFrame {
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getProject(self.plan.root, cols)
  )
}
195200
}

Sources/SparkConnect/Extension.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ extension String {
4545
keyValue.key = self
4646
return keyValue
4747
}
48+
49+
/// An `UnresolvedAttribute` whose `unparsedIdentifier` is this string.
var toUnresolvedAttribute: UnresolvedAttribute {
  var unresolved = UnresolvedAttribute()
  unresolved.unparsedIdentifier = self
  return unresolved
}
4854
}
4955

5056
extension [String: String] {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,20 @@ public actor SparkConnectClient {
252252
request.analyze = .schema(schema)
253253
return request
254254
}
255+
256+
/// Builds a plan whose root relation projects `cols` from `child`.
/// - Parameters:
///   - child: The relation used as the projection's input.
///   - cols: Column names; each one becomes an unresolved-attribute expression.
/// - Returns: A ``Plan`` whose root is the projection relation.
static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
  var projection = Project()
  projection.input = child
  projection.expressions = cols.map { name -> Spark_Connect_Expression in
    var column = Spark_Connect_Expression()
    column.exprType = .unresolvedAttribute(name.toUnresolvedAttribute)
    return column
  }

  var root = Relation()
  root.project = projection

  var result = Plan()
  result.opType = .root(root)
  return result
}
255271
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest
2222
typealias DataType = Spark_Connect_DataType
2323
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
2424
typealias Plan = Spark_Connect_Plan
25+
typealias Project = Spark_Connect_Project
2526
typealias KeyValue = Spark_Connect_KeyValue
2627
typealias Range = Spark_Connect_Range
2728
typealias Relation = Spark_Connect_Relation
2829
typealias SparkConnectService = Spark_Connect_SparkConnectService
2930
typealias UserContext = Spark_Connect_UserContext
31+
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,45 @@ struct DataFrameTests {
6868
await spark.stop()
6969
}
7070

71+
@Test
func selectNone() async throws {
  // Selecting no columns is expected to produce an empty struct schema.
  let session = try await SparkSession.builder.getOrCreate()
  let expected = #"{"struct":{}}"#
  let actual = try await session.range(1).select().schema()
  #expect(actual == expected)
  await session.stop()
}
78+
79+
@Test
func select() async throws {
  // Selecting `id` from a range is expected to yield one long-typed `id` field.
  let session = try await SparkSession.builder.getOrCreate()
  let expected = #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"#
  let actual = try await session.range(1).select("id").schema()
  #expect(actual == expected)
  await session.stop()
}
89+
90+
@Test
func selectMultipleColumns() async throws {
  // The resulting schema is expected to list fields in the order they were selected.
  let session = try await SparkSession.builder.getOrCreate()
  let expected =
    #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"#
  let actual = try await session.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema()
  #expect(actual == expected)
  await session.stop()
}
100+
101+
@Test
func selectInvalidColumn() async throws {
  let session = try await SparkSession.builder.getOrCreate()
  // Requesting the schema after selecting the name `invalid` is expected to throw.
  try await #require(throws: Error.self) {
    _ = try await session.range(1).select("invalid").schema()
  }
  await session.stop()
}
109+
71110
@Test
72111
func table() async throws {
73112
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)