diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 0c93234..be5c5e7 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -330,6 +330,13 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n)) } + /// Returns a new ``Dataset`` by skipping the first `n` rows. + /// - Parameter n: Number of leading rows to skip; `0` keeps every row. + /// - Returns: A ``DataFrame`` whose plan applies an offset of `n` to this DataFrame's plan, leaving the rows after the first `n`. + public func offset(_ n: Int32) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getOffset(self.plan.root, n)) + } + /// Returns a new ``Dataset`` by sampling a fraction of rows, using a user-supplied seed. /// - Parameters: /// - withReplacement: Sample with replacement or not. diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index d76f533..904e76e 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -397,6 +397,17 @@ public actor SparkConnectClient { return plan } + static func getOffset(_ child: Relation, _ n: Int32) -> Plan { + var offset = Spark_Connect_Offset() + offset.input = child + offset.offset = n + var relation = Relation() + relation.offset = offset + var plan = Plan() + plan.opType = .root(relation) + return plan + } + static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan { var sample = Sample() sample.input = child diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index aee0c93..b9c927d 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -218,6 +218,16 @@ struct DataFrameTests { await spark.stop() } + @Test + func offset() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).offset(0).count() 
== 10) + #expect(try await spark.range(10).offset(1).count() == 9) + #expect(try await spark.range(10).offset(2).count() == 8) + #expect(try await spark.range(10).offset(15).count() == 0) + await spark.stop() + } + @Test func sample() async throws { let spark = try await SparkSession.builder.getOrCreate()