Skip to content

Commit b87ad0f

Browse files
committed
[SPARK-51804] Support sample in DataFrame
1 parent 7ab1e45 commit b87ad0f

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ public actor DataFrame: Sendable {
297297
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
298298
}
299299

300+
public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> DataFrame {
301+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed))
302+
}
303+
304+
public func sample(_ withReplacement: Bool, _ fraction: Double) -> DataFrame {
305+
return sample(withReplacement, fraction, Int64.random(in: Int64.min...Int64.max))
306+
}
307+
308+
public func sample(_ fraction: Double, _ seed: Int64) -> DataFrame {
309+
return sample(false, fraction, seed)
310+
}
311+
312+
public func sample(_ fraction: Double) -> DataFrame {
313+
return sample(false, fraction)
314+
}
315+
316+
300317
/// Returns the first `n` rows.
301318
/// - Parameter n: The number of rows. (default: 1)
302319
/// - Returns: ``[[String?]]``

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ public actor SparkConnectClient {
375375
return plan
376376
}
377377

378+
static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan {
379+
var sample = Sample()
380+
sample.input = child
381+
sample.withReplacement = withReplacement
382+
sample.lowerBound = 0.0
383+
sample.upperBound = fraction
384+
sample.seed = seed
385+
var relation = Relation()
386+
relation.sample = sample
387+
var plan = Plan()
388+
plan.opType = .root(relation)
389+
return plan
390+
}
391+
378392
static func getTail(_ child: Relation, _ n: Int32) -> Plan {
379393
var tail = Tail()
380394
tail.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typealias Project = Spark_Connect_Project
3939
typealias Range = Spark_Connect_Range
4040
typealias Read = Spark_Connect_Read
4141
typealias Relation = Spark_Connect_Relation
42+
typealias Sample = Spark_Connect_Sample
4243
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
4344
typealias SparkConnectService = Spark_Connect_SparkConnectService
4445
typealias Sort = Spark_Connect_Sort

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,15 @@ struct DataFrameTests {
196196
await spark.stop()
197197
}
198198

199+
@Test
200+
func sample() async throws {
201+
let spark = try await SparkSession.builder.getOrCreate()
202+
#expect(try await spark.range(100000).sample(0.001).count() < 1000)
203+
#expect(try await spark.range(100000).sample(0.999).count() > 99000)
204+
#expect(try await spark.range(100000).sample(true, 0.001, 0).count() < 1000)
205+
await spark.stop()
206+
}
207+
199208
@Test
200209
func isEmpty() async throws {
201210
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)