43 changes: 42 additions & 1 deletion Sources/SparkConnect/DataFrame.swift
@@ -91,6 +91,7 @@ import Synchronization
/// - ``show(_:_:_:)``
///
/// ### Transformation Operations
/// - ``toDF(_:)``
/// - ``select(_:)``
/// - ``selectExpr(_:)``
/// - ``filter(_:)``
@@ -100,6 +101,9 @@ import Synchronization
/// - ``limit(_:)``
/// - ``offset(_:)``
/// - ``drop(_:)``
/// - ``dropDuplicates(_:)``
/// - ``dropDuplicatesWithinWatermark(_:)``
/// - ``distinct()``
/// - ``withColumnRenamed(_:_:)``
///
/// ### Join Operations
@@ -440,13 +444,25 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: plan)
}

/// Projects a set of expressions and returns a new ``DataFrame``.
/// Selects a subset of existing columns using column names.
/// - Parameter cols: Column names
/// - Returns: A ``DataFrame`` with subset of columns.
public func select(_ cols: String...) -> DataFrame {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
}

/// Selects a subset of existing columns using column names; with no arguments, returns this ``DataFrame`` unchanged.
/// - Parameter cols: Column names
/// - Returns: A ``DataFrame`` with the selected subset of columns.
public func toDF(_ cols: String...) -> DataFrame {
let df = if cols.isEmpty {
DataFrame(spark: self.spark, plan: self.plan)
} else {
DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
}
return df
}
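
For readers of this diff: a minimal usage sketch of the new `toDF` (not part of the change itself). With an empty argument list it reuses the current plan; with names it projects that subset, so `toDF("id")` is assumed to behave like `select("id")` here. The session setup mirrors the tests added below.

```swift
// Hypothetical usage sketch; mirrors the test style used in this PR.
let spark = try await SparkSession.builder.getOrCreate()

// Empty argument list: the plan is reused, so all columns remain.
#expect(try await spark.range(1).toDF().columns == ["id"])

// Non-empty: a projection on the named columns, like `select("id")`.
#expect(try await spark.range(1).toDF("id").columns == ["id"])

await spark.stop()
```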

/// Projects a set of expressions and returns a new ``DataFrame``.
/// - Parameter exprs: Expression strings
/// - Returns: A ``DataFrame`` with subset of columns.
@@ -461,6 +477,24 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getDrop(self.plan.root, cols))
}

/// Returns a new ``DataFrame`` that contains only the unique rows from this ``DataFrame``.
/// With no arguments, this is an alias for `distinct`. If column names are given, Spark considers only those columns.
/// - Parameter cols: Column names
/// - Returns: A ``DataFrame``.
public func dropDuplicates(_ cols: String...) -> DataFrame {
let plan = SparkConnectClient.getDropDuplicates(self.plan.root, cols, withinWatermark: false)
return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a new ``DataFrame`` with duplicate rows removed, within watermark.
/// If column names are given, Spark considers only those columns.
/// - Parameter cols: Column names
/// - Returns: A ``DataFrame``.
public func dropDuplicatesWithinWatermark(_ cols: String...) -> DataFrame {
let plan = SparkConnectClient.getDropDuplicates(self.plan.root, cols, withinWatermark: true)
return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a new ``DataFrame`` with a column renamed. This is a no-op if the schema doesn't contain `existingName`.
/// - Parameters:
/// - existingName: An existing column name to be renamed.
@@ -1108,6 +1142,13 @@ public actor DataFrame: Sendable {
return buildRepartition(numPartitions: numPartitions, shuffle: false)
}

/// Returns a new ``DataFrame`` that contains only the unique rows from this ``DataFrame``.
/// This is an alias for `dropDuplicates`.
/// - Returns: A ``DataFrame``.
public func distinct() -> DataFrame {
return dropDuplicates()
}
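
Taken together, the dedup family added in this PR: `distinct()` forwards to `dropDuplicates()` with no key columns, while `dropDuplicates(cols)` keys only on the named columns. A hedged sketch (data invented for illustration, not from the PR's tests) of how keyed and full deduplication differ:

```swift
// Hypothetical data, not from the PR's tests.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql(
  "SELECT * FROM VALUES (1, 'a'), (1, 'b'), (2, 'a') T(id, tag)")

// All columns as keys: every (id, tag) pair is unique, so all 3 rows survive.
#expect(try await df.distinct().count() == 3)
#expect(try await df.dropDuplicates().count() == 3)

// Keyed on `id` only: one row is kept per id value, so 2 rows survive.
#expect(try await df.dropDuplicates("id").count() == 2)

await spark.stop()
```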

/// Groups the DataFrame using the specified columns.
///
/// This method is used to perform aggregations on groups of data.
20 changes: 20 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -455,6 +455,26 @@ public actor SparkConnectClient {
return plan
}

static func getDropDuplicates(
_ child: Relation,
_ columnNames: [String],
withinWatermark: Bool = false
) -> Plan {
var deduplicate = Spark_Connect_Deduplicate()
deduplicate.input = child
if columnNames.isEmpty {
deduplicate.allColumnsAsKeys = true
} else {
deduplicate.columnNames = columnNames
}
// Forward the flag; without this, `dropDuplicatesWithinWatermark` would be
// indistinguishable from `dropDuplicates` on the wire.
deduplicate.withinWatermark = withinWatermark
var relation = Relation()
relation.deduplicate = deduplicate
var plan = Plan()
plan.opType = .root(relation)
return plan
}
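
A reading note on the helper above: an empty column list sets `allColumnsAsKeys` rather than sending an empty `columnNames`, because the Spark Connect `Deduplicate` relation treats the two fields as mutually exclusive (an assumption based on the upstream proto, not spelled out in this diff). Given some already-built `child: Relation`, the two shapes look like:

```swift
// Module-internal sketch; `child` stands for any Relation already built,
// and the Deduplicate field names are assumed from the Spark Connect proto.
let allKeys = SparkConnectClient.getDropDuplicates(child, [])
// -> Deduplicate { input: child, allColumnsAsKeys: true }

let keyed = SparkConnectClient.getDropDuplicates(child, ["a"], withinWatermark: true)
// -> Deduplicate { input: child, columnNames: ["a"], withinWatermark: true }
```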

static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
var sort = Sort()
sort.input = child
35 changes: 35 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -183,6 +183,7 @@ struct DataFrameTests {
@Test
func select() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(1).select().columns.isEmpty)
let schema = try await spark.range(1).select("id").schema
#expect(
schema
@@ -191,6 +192,14 @@
await spark.stop()
}

@Test
func toDF() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(1).toDF().columns == ["id"])
#expect(try await spark.range(1).toDF("id").columns == ["id"])
await spark.stop()
}

@Test
func selectMultipleColumns() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -647,6 +656,32 @@
await spark.stop()
}

@Test
func distinct() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)")
#expect(try await df.distinct().count() == 3)
await spark.stop()
}

@Test
func dropDuplicates() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)")
#expect(try await df.dropDuplicates().count() == 3)
#expect(try await df.dropDuplicates("a").count() == 3)
await spark.stop()
}

@Test
func dropDuplicatesWithinWatermark() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)")
#expect(try await df.dropDuplicatesWithinWatermark().count() == 3)
#expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3)
await spark.stop()
}

@Test
func groupBy() async throws {
let spark = try await SparkSession.builder.getOrCreate()