diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index b1e831e..1121473 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -297,6 +297,27 @@ public actor DataFrame: Sendable {
     return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
   }
 
+  /// Returns the first `n` rows.
+  ///
+  /// Equivalent to `limit(n)` followed by `collect()`; fewer rows are returned if fewer exist.
+  /// - Parameter n: The number of rows. (default: 1)
+  /// - Returns: ``[[String?]]``
+  /// - Throws: Any error thrown by `collect()`.
+  public func head(_ n: Int32 = 1) async throws -> [[String?]] {
+    return try await limit(n).collect()
+  }
+
+  /// Returns the last `n` rows.
+  ///
+  /// Collects a `Tail` relation built over the current plan; fewer rows are returned if fewer exist.
+  /// - Parameter n: The number of rows.
+  /// - Returns: ``[[String?]]``
+  /// - Throws: Any error thrown by `collect()`.
+  public func tail(_ n: Int32) async throws -> [[String?]] {
+    let lastN = DataFrame(spark: self.spark, plan: SparkConnectClient.getTail(self.plan.root, n))
+    return try await lastN.collect()
+  }
+
   /// Checks if the ``DataFrame`` is empty and returns a boolean value.
   /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
   public func isEmpty() async throws -> Bool {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 3314c55..4e14077 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -373,6 +373,22 @@ public actor SparkConnectClient {
     return plan
   }
 
+  /// Creates a plan whose root is a `Tail` relation over `child`.
+  /// - Parameters:
+  ///   - child: The input relation.
+  ///   - n: The value stored in the `Tail` relation's `limit` field.
+  /// - Returns: A ``Plan`` rooted at the new relation.
+  static func getTail(_ child: Relation, _ n: Int32) -> Plan {
+    var tail = Tail()
+    tail.input = child
+    tail.limit = n
+    var relation = Relation()
+    relation.tail = tail
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   var result: [ExecutePlanResponse] = []
   private func addResponse(_ response: ExecutePlanResponse) {
     self.result.append(response)
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index f82c4f5..198df89 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -41,6 +41,7 @@ typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
 typealias StructType = Spark_Connect_DataType.Struct
+typealias Tail = Spark_Connect_Tail
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
 typealias WriteOperation = Spark_Connect_WriteOperation
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 832b031..36d6084 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -247,6 +247,29 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func head() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    // `head()` defaults to one row; asking for more rows than exist returns them all.
+    #expect(try await spark.range(0).head().isEmpty)
+    #expect(try await spark.range(2).sort("id").head() == [["0"]])
+    #expect(try await spark.range(2).sort("id").head(1) == [["0"]])
+    #expect(try await spark.range(2).sort("id").head(2) == [["0"], ["1"]])
+    #expect(try await spark.range(2).sort("id").head(3) == [["0"], ["1"]])
+    await spark.stop()
+  }
+
+  @Test
+  func tail() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    // Asking for more rows than exist returns them all.
+    #expect(try await spark.range(0).tail(1).isEmpty)
+    #expect(try await spark.range(2).sort("id").tail(1) == [["1"]])
+    #expect(try await spark.range(2).sort("id").tail(2) == [["0"], ["1"]])
+    #expect(try await spark.range(2).sort("id").tail(3) == [["0"], ["1"]])
+    await spark.stop()
+  }
+
   @Test
   func show() async throws {
     let spark = try await SparkSession.builder.getOrCreate()