diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index f595d96..5771588 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -83,7 +83,10 @@ import Synchronization
 /// ### Data Collection
 /// - ``count()``
 /// - ``collect()``
+/// - ``first()``
+/// - ``head()``
 /// - ``head(_:)``
+/// - ``take(_:)``
 /// - ``tail(_:)``
 /// - ``show()``
 /// - ``show(_:)``
@@ -92,6 +95,7 @@ import Synchronization
 ///
 /// ### Transformation Operations
 /// - ``toDF(_:)``
+/// - ``toJSON()``
 /// - ``select(_:)``
 /// - ``selectExpr(_:)``
 /// - ``filter(_:)``
@@ -467,6 +471,12 @@ public actor DataFrame: Sendable {
     return df
   }
 
+  /// Returns the content of the Dataset as a Dataset of JSON strings.
+  /// - Returns: A ``DataFrame`` with a single string column whose content is JSON.
+  public func toJSON() -> DataFrame {
+    return selectExpr("to_json(struct(*))")
+  }
+
   /// Projects a set of expressions and returns a new ``DataFrame``.
   /// - Parameter exprs: Expression strings
   /// - Returns: A ``DataFrame`` with subset of columns.
@@ -685,13 +695,33 @@ public actor DataFrame: Sendable {
   /// let firstFive = try await df.head(5)
   /// ```
   ///
-  /// - Parameter n: Number of rows to return (default: 1)
+  /// - Parameter n: Number of rows to return.
   /// - Returns: An array of ``Row`` objects
   /// - Throws: `SparkConnectError` if the operation fails
-  public func head(_ n: Int32 = 1) async throws -> [Row] {
+  public func head(_ n: Int32) async throws -> [Row] {
     return try await limit(n).collect()
   }
 
+  /// Returns the first row.
+  /// - Returns: A ``Row``.
+  public func head() async throws -> Row {
+    return try await head(1)[0]
+  }
+
+  /// Returns the first row. Alias for head().
+  /// - Returns: A ``Row``.
+  public func first() async throws -> Row {
+    return try await head()
+  }
+
+  /// Returns the first n rows.
+  /// - Parameter n: Number of rows to return.
+  /// - Returns: An array of ``Row`` objects
+  /// - Throws: `SparkConnectError` if the operation fails
+  public func take(_ n: Int32) async throws -> [Row] {
+    return try await head(n)
+  }
+
   /// Returns the last `n` rows.
   /// - Parameter n: The number of rows.
   /// - Returns: ``[Row]``
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 693f371..bea34ed 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -377,18 +377,35 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func first() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(2).sort("id").first() == Row(0))
+    #expect(try await spark.range(2).sort("id").head() == Row(0))
+    await spark.stop()
+  }
+
   @Test
   func head() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(0).head().isEmpty)
-    print(try await spark.range(2).sort("id").head())
-    #expect(try await spark.range(2).sort("id").head() == [Row(0)])
+    #expect(try await spark.range(0).head(1).isEmpty)
+    #expect(try await spark.range(2).sort("id").head() == Row(0))
     #expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
     #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
     #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
     await spark.stop()
   }
 
+  @Test
+  func take() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(0).take(1).isEmpty)
+    #expect(try await spark.range(2).sort("id").take(1) == [Row(0)])
+    #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)])
+    #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)])
+    await spark.stop()
+  }
+
   @Test
   func tail() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
@@ -759,6 +776,18 @@ struct DataFrameTests {
     ])
     await spark.stop()
   }
+
+  @Test
+  func toJSON() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(2).toJSON()
+    #expect(try await df.columns == ["to_json(struct(id))"])
+    #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")])
+
+    let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")]
+    #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected)
+    await spark.stop()
+  }
 #endif
 
   @Test
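
For context, a minimal usage sketch of the APIs this diff adds (`first()`, `head()`, `take(_:)`, `toJSON()`). It is illustrative only, assuming a Spark Connect server reachable through the builder's default endpoint, as in the tests above:

```swift
import SparkConnect

// Sketch only: exercises the new row-access and JSON APIs from this diff,
// assuming a running Spark Connect server.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(3).sort("id")

// head()/first() return a single Row; head(n)/take(n) return [Row].
let firstRow = try await df.first()  // Row(0)
let topTwo = try await df.take(2)    // [Row(0), Row(1)]

// toJSON() wraps each row into a JSON string in a single-column DataFrame.
let json = try await spark.sql("SELECT 1 a, 2 b").toJSON().collect()
print(firstRow, topTwo, json)        // json: [Row("{\"a\":1,\"b\":2}")]

await spark.stop()
```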