diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index df6e325..72263df 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -297,6 +297,44 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n)) } + /// Returns a new ``DataFrame`` by sampling a fraction of rows, using a user-supplied seed. + /// - Parameters: + /// - withReplacement: Sample with replacement or not. + /// - fraction: Fraction of rows to generate, range [0.0, 1.0]. + /// - seed: Seed for sampling. + /// - Returns: A subset of the records. + public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed)) + } + + /// Returns a new ``DataFrame`` by sampling a fraction of rows, using a random seed. + /// - Parameters: + /// - withReplacement: Sample with replacement or not. + /// - fraction: Fraction of rows to generate, range [0.0, 1.0]. + /// - Returns: A subset of the records. + public func sample(_ withReplacement: Bool, _ fraction: Double) -> DataFrame { + return sample(withReplacement, fraction, Int64.random(in: Int64.min...Int64.max)) + } + + /// Returns a new ``DataFrame`` by sampling a fraction of rows (without replacement), using a + /// user-supplied seed. + /// - Parameters: + /// - fraction: Fraction of rows to generate, range [0.0, 1.0]. + /// - seed: Seed for sampling. + /// - Returns: A subset of the records. + public func sample(_ fraction: Double, _ seed: Int64) -> DataFrame { + return sample(false, fraction, seed) + } + + /// Returns a new ``DataFrame`` by sampling a fraction of rows (without replacement), using a + /// random seed. + /// - Parameters: + /// - fraction: Fraction of rows to generate, range [0.0, 1.0]. + /// - Returns: A subset of the records. 
+ public func sample(_ fraction: Double) -> DataFrame { + return sample(false, fraction) + } + /// Returns the first `n` rows. /// - Parameter n: The number of rows. (default: 1) /// - Returns: ``[[String?]]`` diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 1e8087c..2acbd6e 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -375,6 +375,20 @@ public actor SparkConnectClient { return plan } + static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan { + var sample = Sample() + sample.input = child + sample.withReplacement = withReplacement + sample.lowerBound = 0.0 + sample.upperBound = fraction + sample.seed = seed + var relation = Relation() + relation.sample = sample + var plan = Plan() + plan.opType = .root(relation) + return plan + } + static func getTail(_ child: Relation, _ n: Int32) -> Plan { var tail = Tail() tail.input = child diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 6a700d6..766ad02 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -39,6 +39,7 @@ typealias Project = Spark_Connect_Project typealias Range = Spark_Connect_Range typealias Read = Spark_Connect_Read typealias Relation = Spark_Connect_Relation +typealias Sample = Spark_Connect_Sample typealias SaveMode = Spark_Connect_WriteOperation.SaveMode typealias SparkConnectService = Spark_Connect_SparkConnectService typealias Sort = Spark_Connect_Sort diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 9ec515d..1e602c5 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -196,6 +196,15 @@ struct DataFrameTests { await spark.stop() } + @Test + func sample() async throws { + let spark = try await 
SparkSession.builder.getOrCreate() + #expect(try await spark.range(100000).sample(0.001).count() < 1000) + #expect(try await spark.range(100000).sample(0.999).count() > 99000) + #expect(try await spark.range(100000).sample(true, 0.001, 0).count() < 1000) + await spark.stop() + } + @Test func isEmpty() async throws { let spark = try await SparkSession.builder.getOrCreate()