Skip to content

Commit 5684325

Browse files
committed
[SPARK-51996] Support describe and summary in DataFrame
### What changes were proposed in this pull request? This PR aims to support `describe` and `summary` API of `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #112 from dongjoon-hyun/SPARK-51996. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent ccaa92b commit 5684325

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ import Synchronization
164164
/// - ``sample(_:_:)``
165165
/// - ``sample(_:)``
166166
///
167+
/// ### Statistics
168+
/// - ``describe(_:)``
169+
/// - ``summary(_:)``
170+
///
167171
/// ### Utility Methods
168172
/// - ``isEmpty()``
169173
/// - ``isLocal()``
@@ -495,6 +499,25 @@ public actor DataFrame: Sendable {
495499
return DataFrame(spark: self.spark, plan: plan)
496500
}
497501

502+
/// Computes basic statistics for numeric and string columns: count, mean, stddev, min,
/// and max. When no column names are supplied, statistics are computed for every
/// numerical or string column.
/// - Parameter cols: Column names.
/// - Returns: A ``DataFrame`` containing basic statistics.
public func describe(_ cols: String...) -> DataFrame {
  let plan = SparkConnectClient.getDescribe(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: plan)
}
510+
511+
/// Computes the specified statistics for numeric and string columns. Available statistics are:
/// count, mean, stddev, min, max, arbitrary approximate percentiles specified as a percentage
/// (e.g. 75%), count_distinct, and approx_count_distinct. When no statistics are given, this
/// function computes count, mean, stddev, min, approximate quartiles (percentiles at 25%, 50%,
/// and 75%), and max.
/// - Parameter statistics: Statistics names.
/// - Returns: A ``DataFrame`` containing the specified statistics.
public func summary(_ statistics: String...) -> DataFrame {
  let plan = SparkConnectClient.getSummary(self.plan.root, statistics)
  return DataFrame(spark: self.spark, plan: plan)
}
520+
498521
/// Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain existingName.
499522
/// - Parameters:
500523
/// - existingName: An existing column name to be renamed.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,28 @@ public actor SparkConnectClient {
474474
return plan
475475
}
476476

477+
/// Builds a `Plan` whose root is a `StatDescribe` relation over `child`.
/// - Parameters:
///   - child: The input relation.
///   - cols: Column names to describe; an empty array requests all applicable columns.
/// - Returns: A `Plan` wrapping the describe relation.
static func getDescribe(_ child: Relation, _ cols: [String]) -> Plan {
  var statDescribe = Spark_Connect_StatDescribe()
  statDescribe.input = child
  statDescribe.cols = cols

  var rel = Relation()
  rel.describe = statDescribe

  var result = Plan()
  result.opType = .root(rel)
  return result
}
487+
488+
/// Builds a `Plan` whose root is a `StatSummary` relation over `child`.
/// - Parameters:
///   - child: The input relation.
///   - statistics: Statistic names to compute; an empty array requests the defaults.
/// - Returns: A `Plan` wrapping the summary relation.
static func getSummary(_ child: Relation, _ statistics: [String]) -> Plan {
  var statSummary = Spark_Connect_StatSummary()
  statSummary.input = child
  statSummary.statistics = statistics

  var rel = Relation()
  rel.summary = statSummary

  var result = Plan()
  result.opType = .root(rel)
  return result
}
498+
477499
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
478500
var sort = Sort()
479501
sort.input = child

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,28 @@ struct DataFrameTests {
682682
await spark.stop()
683683
}
684684

685+
@Test
func describe() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // count, mean, stddev, min, max of range(10), rendered as strings.
  let answer = [Row("10"), Row("4.5"), Row("3.0276503540974917"), Row("0"), Row("9")]
  let df = try await spark.range(10)
  // With no arguments, all numerical/string columns are described; here that is `id` alone.
  #expect(try await df.describe().select("id").collect() == answer)
  #expect(try await df.describe("id").select("id").collect() == answer)
  await spark.stop()
}
694+
695+
@Test
func summary() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // Default statistics: count, mean, stddev, min, 25%, 50%, 75%, max — as strings.
  let defaultStats = [
    Row("10"), Row("4.5"), Row("3.0276503540974917"),
    Row("0"), Row("2"), Row("4"), Row("7"), Row("9")
  ]
  #expect(try await spark.range(10).summary().select("id").collect() == defaultStats)
  // Explicit statistic names restrict the output rows.
  #expect(try await spark.range(10).summary("min", "max").select("id").collect() == [Row("0"), Row("9")])
  await spark.stop()
}
706+
685707
@Test
686708
func groupBy() async throws {
687709
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)