Skip to content

Commit 1724a6b

Browse files
committed
[SPARK-51807] Support drop and withColumnRenamed in DataFrame
1 parent 43714e0 commit 1724a6b

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,39 @@ public actor DataFrame: Sendable {
262262
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
263263
}
264264

265+
/// Returns a new ``DataFrame`` without the specified columns.
/// Unknown column names are silently ignored, so this is a no-op when the
/// schema contains none of them.
/// - Parameter cols: One or more column names to remove.
/// - Returns: A ``DataFrame`` with the remaining subset of columns.
public func drop(_ cols: String...) -> DataFrame {
  let plan = SparkConnectClient.getDrop(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: plan)
}
271+
272+
/// Returns a new ``DataFrame`` with a single column renamed.
/// This is a no-op when the schema doesn't contain `existingName`.
/// - Parameters:
///   - existingName: The existing column name to rename.
///   - newName: The new column name.
/// - Returns: A ``DataFrame`` with the renamed column.
public func withColumnRenamed(_ existingName: String, _ newName: String) -> DataFrame {
  // Delegate to the list-based overload; a single rename is just a pair of
  // one-element lists.
  return withColumnRenamed([existingName], [newName])
}
280+
281+
/// Returns a new Dataset with columns renamed. This is a no-op if schema doesn't contain existingName.
/// - Parameters:
///   - colNames: A list of existing column names to be renamed, paired
///     positionally with `newColNames`. If the two lists differ in length,
///     the extra entries of the longer list are ignored (zip stops at the
///     shorter list).
///   - newColNames: A list of new column names.
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colNames: [String], _ newColNames: [String]) -> DataFrame {
  // Use `uniquingKeysWith:` so a duplicated entry in `colNames` keeps the
  // last mapping instead of trapping at runtime —
  // Dictionary(uniqueKeysWithValues:) crashes on duplicate keys.
  let dic = Dictionary(zip(colNames, newColNames), uniquingKeysWith: { _, last in last })
  // Delegate to the dictionary-based overload, consistent with
  // withColumnRenamed(_:_:) for a single column.
  return withColumnRenamed(dic)
}
290+
291+
/// Returns a new ``DataFrame`` with columns renamed according to `colsMap`.
/// Entries whose key is not in the schema are silently ignored, so this is a
/// no-op when nothing matches.
/// - Parameter colsMap: A dictionary mapping each existing column name to its new name.
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colsMap: [String: String]) -> DataFrame {
  let plan = SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap)
  return DataFrame(spark: self.spark, plan: plan)
}
297+
265298
/// Return a new ``DataFrame`` with filtered rows using the given expression.
266299
/// - Parameter conditionExpr: A string to filter.
267300
/// - Returns: A ``DataFrame`` with subset of rows.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,17 @@ public actor SparkConnectClient {
335335
return plan
336336
}
337337

338+
/// Builds a `WithColumnsRenamed` plan over `child`.
/// - Parameters:
///   - child: The input relation.
///   - colsMap: Mapping of existing column names to their new names.
/// - Returns: A `Plan` whose root is the rename relation.
static func getWithColumnRenamed(_ child: Relation, _ colsMap: [String: String]) -> Plan {
  // Populate the nested message through the relation's accessor.
  var relation = Relation()
  relation.withColumnsRenamed.input = child
  relation.withColumnsRenamed.renameColumnsMap = colsMap
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
348+
338349
static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan {
339350
var filter = Filter()
340351
filter.input = child
@@ -346,6 +357,17 @@ public actor SparkConnectClient {
346357
return plan
347358
}
348359

360+
/// Builds a `Drop` plan over `child` that removes the given columns.
/// - Parameters:
///   - child: The input relation.
///   - columnNames: Names of the columns to drop.
/// - Returns: A `Plan` whose root is the drop relation.
static func getDrop(_ child: Relation, _ columnNames: [String]) -> Plan {
  // Populate the nested message through the relation's accessor.
  var relation = Relation()
  relation.drop.input = child
  relation.drop.columnNames = columnNames
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
370+
349371
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
350372
var sort = Sort()
351373
sort.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest
2323
typealias DataSource = Spark_Connect_Read.DataSource
2424
typealias DataType = Spark_Connect_DataType
2525
typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
26+
typealias Drop = Spark_Connect_Drop
2627
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
2728
typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
2829
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
@@ -47,5 +48,6 @@ typealias StructType = Spark_Connect_DataType.Struct
4748
typealias Tail = Spark_Connect_Tail
4849
typealias UserContext = Spark_Connect_UserContext
4950
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
51+
typealias WithColumnsRenamed = Spark_Connect_WithColumnsRenamed
5052
typealias WriteOperation = Spark_Connect_WriteOperation
5153
typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,28 @@ struct DataFrameTests {
172172
await spark.stop()
173173
}
174174

175+
@Test
func withColumnRenamed() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  // Single-column rename via the (existingName, newName) overload.
  #expect(try await spark.range(1).withColumnRenamed("id", "id2").columns == ["id2"])

  let df = try await spark.sql("SELECT 1 a, 2 b, 3 c, 4 d")
  // Dictionary overload renames several columns; column order is preserved.
  #expect(try await df.withColumnRenamed(["a": "x", "c": "z"]).columns == ["x", "b", "z", "d"])
  // Renaming a column that doesn't exist is a no-op.
  #expect(try await df.withColumnRenamed(["unknown": "x"]).columns == ["a", "b", "c", "d"])

  await spark.stop()
}
185+
186+
@Test
func drop() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.sql("SELECT 1 a, 2 b, 3 c, 4 d")

  // Dropping one or several existing columns.
  #expect(try await df.drop("a").collect() == [["2", "3", "4"]])
  #expect(try await df.drop("b", "c").collect() == [["1", "4"]])
  // Dropping columns that don't exist is a no-op.
  #expect(try await df.drop("x", "y").collect() == [["1", "2", "3", "4"]])

  await spark.stop()
}
196+
175197
@Test
176198
func filter() async throws {
177199
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)