Skip to content

Commit 7de35f7

Browse files
committed
[SPARK-52064] Support first/take/toJSON in DataFrame
### What changes were proposed in this pull request?

This PR aims to support the `first`, `take`, and `toJSON` APIs in `DataFrame`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs with the newly added test cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #122 from dongjoon-hyun/SPARK-52064.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent f6f4a2d commit 7de35f7

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ import Synchronization
8383
/// ### Data Collection
8484
/// - ``count()``
8585
/// - ``collect()``
86+
/// - ``first()``
87+
/// - ``head()``
8688
/// - ``head(_:)``
89+
/// - ``take(_:)``
8790
/// - ``tail(_:)``
8891
/// - ``show()``
8992
/// - ``show(_:)``
@@ -92,6 +95,7 @@ import Synchronization
9295
///
9396
/// ### Transformation Operations
9497
/// - ``toDF(_:)``
98+
/// - ``toJSON()``
9599
/// - ``select(_:)``
96100
/// - ``selectExpr(_:)``
97101
/// - ``filter(_:)``
@@ -467,6 +471,12 @@ public actor DataFrame: Sendable {
467471
return df
468472
}
469473

474+
/// Returns the content of the Dataset as a Dataset of JSON strings.
475+
/// - Returns: A ``DataFrame`` with a single string column whose content is JSON.
476+
public func toJSON() -> DataFrame {
477+
return selectExpr("to_json(struct(*))")
478+
}
479+
470480
/// Projects a set of expressions and returns a new ``DataFrame``.
471481
/// - Parameter exprs: Expression strings
472482
/// - Returns: A ``DataFrame`` with subset of columns.
@@ -685,13 +695,33 @@ public actor DataFrame: Sendable {
685695
/// let firstFive = try await df.head(5)
686696
/// ```
687697
///
688-
/// - Parameter n: Number of rows to return (default: 1)
698+
/// - Parameter n: Number of rows to return.
689699
/// - Returns: An array of ``Row`` objects
690700
/// - Throws: `SparkConnectError` if the operation fails
691-
public func head(_ n: Int32 = 1) async throws -> [Row] {
701+
public func head(_ n: Int32) async throws -> [Row] {
692702
return try await limit(n).collect()
693703
}
694704

705+
/// Returns the first row.
706+
/// - Returns: A ``Row``.
707+
public func head() async throws -> Row {
708+
return try await head(1)[0]
709+
}
710+
711+
/// Returns the first row. Alias for head().
712+
/// - Returns: A ``Row``.
713+
public func first() async throws -> Row {
714+
return try await head()
715+
}
716+
717+
/// Returns the first n rows.
718+
/// - Parameter n: Number of rows to return.
719+
/// - Returns: An array of ``Row`` objects
720+
/// - Throws: `SparkConnectError` if the operation fails
721+
public func take(_ n: Int32) async throws -> [Row] {
722+
return try await head(n)
723+
}
724+
695725
/// Returns the last `n` rows.
696726
/// - Parameter n: The number of rows.
697727
/// - Returns: ``[Row]``

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,18 +377,35 @@ struct DataFrameTests {
377377
await spark.stop()
378378
}
379379

380+
@Test
381+
func first() async throws {
382+
let spark = try await SparkSession.builder.getOrCreate()
383+
#expect(try await spark.range(2).sort("id").first() == Row(0))
384+
#expect(try await spark.range(2).sort("id").head() == Row(0))
385+
await spark.stop()
386+
}
387+
380388
@Test
381389
func head() async throws {
382390
let spark = try await SparkSession.builder.getOrCreate()
383-
#expect(try await spark.range(0).head().isEmpty)
384-
print(try await spark.range(2).sort("id").head())
385-
#expect(try await spark.range(2).sort("id").head() == [Row(0)])
391+
#expect(try await spark.range(0).head(1).isEmpty)
392+
#expect(try await spark.range(2).sort("id").head() == Row(0))
386393
#expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
387394
#expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
388395
#expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
389396
await spark.stop()
390397
}
391398

399+
@Test
400+
func take() async throws {
401+
let spark = try await SparkSession.builder.getOrCreate()
402+
#expect(try await spark.range(0).take(1).isEmpty)
403+
#expect(try await spark.range(2).sort("id").take(1) == [Row(0)])
404+
#expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)])
405+
#expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)])
406+
await spark.stop()
407+
}
408+
392409
@Test
393410
func tail() async throws {
394411
let spark = try await SparkSession.builder.getOrCreate()
@@ -759,6 +776,18 @@ struct DataFrameTests {
759776
])
760777
await spark.stop()
761778
}
779+
780+
@Test
781+
func toJSON() async throws {
782+
let spark = try await SparkSession.builder.getOrCreate()
783+
let df = try await spark.range(2).toJSON()
784+
#expect(try await df.columns == ["to_json(struct(id))"])
785+
#expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")])
786+
787+
let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")]
788+
#expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected)
789+
await spark.stop()
790+
}
762791
#endif
763792

764793
@Test

0 commit comments

Comments (0)