Skip to content

Commit 43714e0

Browse files
committed
[SPARK-51804] Support sample in DataFrame
### What changes were proposed in this pull request? This PR aims to support four `sample` APIs in `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. This is a new addition. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #59 from dongjoon-hyun/SPARK-51804. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7ab1e45 commit 43714e0

File tree

4 files changed

+62
-0
lines changed

4 files changed

+62
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,44 @@ public actor DataFrame: Sendable {
297297
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
298298
}
299299

300+
/// Returns a new ``Dataset`` by sampling a fraction of rows, using a user-supplied seed.
301+
/// - Parameters:
302+
/// - withReplacement: Sample with replacement or not.
303+
/// - fraction: Fraction of rows to generate, range [0.0, 1.0].
304+
/// - seed: Seed for sampling.
305+
/// - Returns: A subset of the records.
306+
public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> DataFrame {
307+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed))
308+
}
309+
310+
/// Returns a new ``Dataset`` by sampling a fraction of rows, using a random seed.
311+
/// - Parameters:
312+
/// - withReplacement: Sample with replacement or not.
313+
/// - fraction: Fraction of rows to generate, range [0.0, 1.0].
314+
/// - Returns: A subset of the records.
315+
public func sample(_ withReplacement: Bool, _ fraction: Double) -> DataFrame {
316+
return sample(withReplacement, fraction, Int64.random(in: Int64.min...Int64.max))
317+
}
318+
319+
/// Returns a new ``Dataset`` by sampling a fraction of rows (without replacement), using a
320+
/// user-supplied seed.
321+
/// - Parameters:
322+
/// - fraction: Fraction of rows to generate, range [0.0, 1.0].
323+
/// - seed: Seed for sampling.
324+
/// - Returns: A subset of the records.
325+
public func sample(_ fraction: Double, _ seed: Int64) -> DataFrame {
326+
return sample(false, fraction, seed)
327+
}
328+
329+
/// Returns a new ``Dataset`` by sampling a fraction of rows (without replacement), using a
330+
/// random seed.
331+
/// - Parameters:
332+
/// - fraction: Fraction of rows to generate, range [0.0, 1.0].
333+
/// - Returns: A subset of the records.
334+
public func sample(_ fraction: Double) -> DataFrame {
335+
return sample(false, fraction)
336+
}
337+
300338
/// Returns the first `n` rows.
301339
/// - Parameter n: The number of rows. (default: 1)
302340
/// - Returns: ``[[String?]]``

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ public actor SparkConnectClient {
375375
return plan
376376
}
377377

378+
static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan {
379+
var sample = Sample()
380+
sample.input = child
381+
sample.withReplacement = withReplacement
382+
sample.lowerBound = 0.0
383+
sample.upperBound = fraction
384+
sample.seed = seed
385+
var relation = Relation()
386+
relation.sample = sample
387+
var plan = Plan()
388+
plan.opType = .root(relation)
389+
return plan
390+
}
391+
378392
static func getTail(_ child: Relation, _ n: Int32) -> Plan {
379393
var tail = Tail()
380394
tail.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typealias Project = Spark_Connect_Project
3939
typealias Range = Spark_Connect_Range
4040
typealias Read = Spark_Connect_Read
4141
typealias Relation = Spark_Connect_Relation
42+
typealias Sample = Spark_Connect_Sample
4243
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
4344
typealias SparkConnectService = Spark_Connect_SparkConnectService
4445
typealias Sort = Spark_Connect_Sort

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,15 @@ struct DataFrameTests {
196196
await spark.stop()
197197
}
198198

199+
@Test
200+
func sample() async throws {
201+
let spark = try await SparkSession.builder.getOrCreate()
202+
#expect(try await spark.range(100000).sample(0.001).count() < 1000)
203+
#expect(try await spark.range(100000).sample(0.999).count() > 99000)
204+
#expect(try await spark.range(100000).sample(true, 0.001, 0).count() < 1000)
205+
await spark.stop()
206+
}
207+
199208
@Test
200209
func isEmpty() async throws {
201210
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)