Skip to content

Commit 18c3d18

Browse files
committed
[SPARK-51729] Support head/tail for DataFrame
### What changes were proposed in this pull request? This PR aims to support `head/tail` API for `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs and do the manual test. ``` $ swift test --filter DataFrameTests.head ... 􀟈 Test head() started. 􁁛 Test head() passed after 0.232 seconds. 􁁛 Suite DataFrameTests passed after 0.232 seconds. 􁁛 Test run with 1 test passed after 0.232 seconds. $ swift test --filter DataFrameTests.tail ... 􀟈 Test tail() started. 􁁛 Test tail() passed after 0.229 seconds. 􁁛 Suite DataFrameTests passed after 0.229 seconds. 􁁛 Test run with 1 test passed after 0.229 seconds. ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43 from dongjoon-hyun/SPARK-51729. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent f81aba2 commit 18c3d18

File tree

4 files changed

+48
-0
lines changed

4 files changed

+48
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,21 @@ public actor DataFrame: Sendable {
297297
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
298298
}
299299

300+
/// Returns the first `n` rows of this ``DataFrame``.
/// - Parameter n: The number of rows to return. (default: 1)
/// - Returns: ``[[String?]]``
public func head(_ n: Int32 = 1) async throws -> [[String?]] {
  // Delegate to `limit` and materialize the rows.
  try await self.limit(n).collect()
}
306+
307+
/// Returns the last `n` rows.
/// - Parameter n: The number of rows.
/// - Returns: ``[[String?]]``
public func tail(_ n: Int32) async throws -> [[String?]] {
  // `spark: self.spark` (with a space after the label) keeps the call style
  // consistent with the sibling `limit` implementation.
  let lastN = DataFrame(spark: self.spark, plan: SparkConnectClient.getTail(self.plan.root, n))
  return try await lastN.collect()
}
314+
300315
/// Checks if the ``DataFrame`` is empty and returns a boolean value.
301316
/// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
302317
public func isEmpty() async throws -> Bool {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,17 @@ public actor SparkConnectClient {
373373
return plan
374374
}
375375

376+
/// Builds a `Plan` whose root is a `Tail` relation over `child`.
/// - Parameters:
///   - child: The input relation.
///   - n: The number of trailing rows to keep.
/// - Returns: A `Plan` rooted at the new `Tail` relation.
static func getTail(_ child: Relation, _ n: Int32) -> Plan {
  var tailOp = Tail()
  tailOp.input = child
  tailOp.limit = n

  var rootRelation = Relation()
  rootRelation.tail = tailOp

  var result = Plan()
  result.opType = .root(rootRelation)
  return result
}
386+
376387
var result: [ExecutePlanResponse] = []
377388
private func addResponse(_ response: ExecutePlanResponse) {
378389
self.result.append(response)

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
4141
typealias SparkConnectService = Spark_Connect_SparkConnectService
4242
typealias Sort = Spark_Connect_Sort
4343
typealias StructType = Spark_Connect_DataType.Struct
44+
typealias Tail = Spark_Connect_Tail
4445
typealias UserContext = Spark_Connect_UserContext
4546
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
4647
typealias WriteOperation = Spark_Connect_WriteOperation

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,27 @@ struct DataFrameTests {
247247
await spark.stop()
248248
}
249249

250+
@Test
func head() async throws {
  let session = try await SparkSession.builder.getOrCreate()
  // An empty DataFrame yields no rows.
  #expect(try await session.range(0).head().isEmpty)
  // With no argument, `head` defaults to a single row.
  #expect(try await session.range(2).sort("id").head() == [["0"]])
  #expect(try await session.range(2).sort("id").head(1) == [["0"]])
  #expect(try await session.range(2).sort("id").head(2) == [["0"], ["1"]])
  // Asking for more rows than exist returns everything.
  #expect(try await session.range(2).sort("id").head(3) == [["0"], ["1"]])
  await session.stop()
}
260+
261+
@Test
func tail() async throws {
  let session = try await SparkSession.builder.getOrCreate()
  // An empty DataFrame yields no rows.
  #expect(try await session.range(0).tail(1).isEmpty)
  #expect(try await session.range(2).sort("id").tail(1) == [["1"]])
  #expect(try await session.range(2).sort("id").tail(2) == [["0"], ["1"]])
  // Asking for more rows than exist returns everything.
  #expect(try await session.range(2).sort("id").tail(3) == [["0"], ["1"]])
  await session.stop()
}
270+
250271
@Test
251272
func show() async throws {
252273
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)