33 changes: 33 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -262,6 +262,39 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
}

/// Returns a new Dataset with columns dropped. This is a no-op if the schema doesn't contain the given column names.
/// - Parameter cols: Column names to drop.
/// - Returns: A ``DataFrame`` with a subset of columns.
public func drop(_ cols: String...) -> DataFrame {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getDrop(self.plan.root, cols))
}

/// Returns a new Dataset with a column renamed. This is a no-op if the schema doesn't contain `existingName`.
/// - Parameters:
///   - existingName: An existing column name to be renamed.
///   - newName: A new column name.
/// - Returns: A ``DataFrame`` with the renamed column.
public func withColumnRenamed(_ existingName: String, _ newName: String) -> DataFrame {
return withColumnRenamed([existingName: newName])
}

/// Returns a new Dataset with columns renamed. This is a no-op for any existing names missing from the schema.
/// - Parameters:
///   - colNames: A list of existing column names to be renamed.
///   - newColNames: A list of new column names, paired positionally with `colNames`.
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colNames: [String], _ newColNames: [String]) -> DataFrame {
let dic = Dictionary(uniqueKeysWithValues: zip(colNames, newColNames))
return DataFrame(spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, dic))
}

/// Returns a new Dataset with columns renamed. This is a no-op for any existing names missing from the schema.
/// - Parameter colsMap: A dictionary mapping existing column names to new column names.
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colsMap: [String: String]) -> DataFrame {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap))
}

/// Returns a new ``DataFrame`` with rows filtered by the given expression.
/// - Parameter conditionExpr: A SQL expression string used to filter rows.
/// - Returns: A ``DataFrame`` with a subset of rows.
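For reviewers, a minimal usage sketch of the new API surface (not part of the diff; it assumes a reachable Spark Connect server, and the column names are illustrative):

import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT 1 a, 2 b, 3 c")

// Drop one or more columns; unknown names are silently ignored.
print(try await df.drop("c").columns)                                  // ["a", "b"]

// Rename one column, two parallel lists pairwise, or a dictionary at once.
print(try await df.withColumnRenamed("a", "x").columns)                // ["x", "b", "c"]
print(try await df.withColumnRenamed(["a", "b"], ["x", "y"]).columns)  // ["x", "y", "c"]
print(try await df.withColumnRenamed(["a": "x"]).columns)              // ["x", "b", "c"]

await spark.stop()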
22 changes: 22 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -335,6 +335,17 @@ public actor SparkConnectClient {
return plan
}

static func getWithColumnRenamed(_ child: Relation, _ colsMap: [String: String]) -> Plan {
var withColumnsRenamed = WithColumnsRenamed()
withColumnsRenamed.input = child
withColumnsRenamed.renameColumnsMap = colsMap
var relation = Relation()
relation.withColumnsRenamed = withColumnsRenamed
var plan = Plan()
plan.opType = .root(relation)
return plan
}

static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan {
var filter = Filter()
filter.input = child
@@ -346,6 +357,17 @@
return plan
}

static func getDrop(_ child: Relation, _ columnNames: [String]) -> Plan {
var drop = Drop()
drop.input = child
drop.columnNames = columnNames
var relation = Relation()
relation.drop = drop
var plan = Plan()
plan.opType = .root(relation)
return plan
}

static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
var sort = Sort()
sort.input = child
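Both new helpers follow the same builder shape: wrap the child `Relation` in the new relation message, then wrap that relation in a root `Plan`. A composition sketch (not part of this diff; `renameThenDrop` is a hypothetical helper inside this actor, and reading `renamed.root` back out of the plan mirrors how `DataFrame` passes `self.plan.root` into these builders):

// Hypothetical: a plan that renames column "a" to "x", then drops "b".
static func renameThenDrop(_ child: Relation) -> Plan {
  let renamed = getWithColumnRenamed(child, ["a": "x"])
  return getDrop(renamed.root, ["b"])
}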
2 changes: 2 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -23,6 +23,7 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataSource = Spark_Connect_Read.DataSource
typealias DataType = Spark_Connect_DataType
typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
typealias Drop = Spark_Connect_Drop
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
@@ -47,5 +48,6 @@ typealias StructType = Spark_Connect_DataType.Struct
typealias Tail = Spark_Connect_Tail
typealias UserContext = Spark_Connect_UserContext
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
typealias WithColumnsRenamed = Spark_Connect_WithColumnsRenamed
typealias WriteOperation = Spark_Connect_WriteOperation
typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
22 changes: 22 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -172,6 +172,28 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func withColumnRenamed() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(1).withColumnRenamed("id", "id2").columns == ["id2"])
let df = try await spark.sql("SELECT 1 a, 2 b, 3 c, 4 d")
#expect(try await df.withColumnRenamed(["a": "x", "c": "z"]).columns == ["x", "b", "z", "d"])
// Ignore unknown column names.
#expect(try await df.withColumnRenamed(["unknown": "x"]).columns == ["a", "b", "c", "d"])
await spark.stop()
}

@Test
func drop() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT 1 a, 2 b, 3 c, 4 d")
#expect(try await df.drop("a").columns == ["b", "c", "d"])
#expect(try await df.drop("b", "c").columns == ["a", "d"])
// Ignore unknown column names.
#expect(try await df.drop("x", "y").columns == ["a", "b", "c", "d"])
await spark.stop()
}

@Test
func filter() async throws {
let spark = try await SparkSession.builder.getOrCreate()
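Beyond the committed tests, a sketch of a combined-use test in the same style (not part of this PR; it assumes the same running Spark Connect server as the tests above):

@Test
func renameThenDrop() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.sql("SELECT 1 a, 2 b, 3 c")
  // Rename first, then drop; both operations ignore unknown column names.
  #expect(try await df.withColumnRenamed("a", "x").drop("b").columns == ["x", "c"])
  await spark.stop()
}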