Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,45 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
}

/// Returns a new ``DataFrame`` containing a randomly sampled subset of rows, using a
/// user-supplied seed.
/// - Parameters:
///   - withReplacement: Sample with replacement or not.
///   - fraction: Fraction of rows to generate, range [0.0, 1.0].
///   - seed: Seed for sampling.
/// - Returns: A subset of the records.
public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> DataFrame {
  // Build the Sample relation plan first, then wrap it in a new DataFrame.
  let samplePlan = SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed)
  return DataFrame(spark: self.spark, plan: samplePlan)
}

/// Returns a new ``DataFrame`` by sampling a fraction of rows, using a random seed.
/// - Parameters:
///   - withReplacement: Sample with replacement or not.
///   - fraction: Fraction of rows to generate, range [0.0, 1.0].
/// - Returns: A subset of the records.
public func sample(_ withReplacement: Bool, _ fraction: Double) -> DataFrame {
  // Delegate to the seeded variant with a freshly drawn random seed spanning
  // the full Int64 range, mirroring Spark's use of a random long seed.
  return sample(withReplacement, fraction, Int64.random(in: Int64.min...Int64.max))
}

/// Returns a new ``DataFrame`` by sampling a fraction of rows (without replacement), using a
/// user-supplied seed.
/// - Parameters:
///   - fraction: Fraction of rows to generate, range [0.0, 1.0].
///   - seed: Seed for sampling.
/// - Returns: A subset of the records.
public func sample(_ fraction: Double, _ seed: Int64) -> DataFrame {
  // Without-replacement sampling is the fully-specified variant with
  // `withReplacement` fixed to false.
  sample(false, fraction, seed)
}

/// Returns a new ``DataFrame`` by sampling a fraction of rows (without replacement), using a
/// random seed.
/// - Parameter fraction: Fraction of rows to generate, range [0.0, 1.0].
/// - Returns: A subset of the records.
public func sample(_ fraction: Double) -> DataFrame {
  // Forward to the two-argument overload, which picks a random seed itself.
  sample(false, fraction)
}

/// Returns the first `n` rows.
/// - Parameter n: The number of rows. (default: 1)
/// - Returns: ``[[String?]]``
Expand Down
14 changes: 14 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ public actor SparkConnectClient {
return plan
}

/// Builds a `Plan` whose root is a `Sample` relation over `child`.
/// - Parameters:
///   - child: The input relation to sample from.
///   - withReplacement: Whether a row may be selected more than once.
///   - fraction: Fraction of rows to generate; becomes the sampling upper bound
///     (lower bound is fixed at 0.0).
///   - seed: Seed forwarded to the server-side sampler.
/// - Returns: A `Plan` wrapping the constructed sample relation.
static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan {
  var sample = Sample()
  sample.input = child
  sample.withReplacement = withReplacement
  // The protocol expresses the fraction as a [lowerBound, upperBound] range.
  sample.lowerBound = 0.0
  sample.upperBound = fraction
  sample.seed = seed

  var relation = Relation()
  relation.sample = sample

  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}

static func getTail(_ child: Relation, _ n: Int32) -> Plan {
var tail = Tail()
tail.input = child
Expand Down
1 change: 1 addition & 0 deletions Sources/SparkConnect/TypeAliases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias Sample = Spark_Connect_Sample
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
Expand Down
9 changes: 9 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func sample() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // Probabilistic bounds: sampling 100k rows at fraction 0.001 yields ~100 rows
  // (at 0.999, ~99900), so these loose thresholds fail only with negligible
  // probability even though the first two calls use a random seed.
  #expect(try await spark.range(100000).sample(0.001).count() < 1000)
  #expect(try await spark.range(100000).sample(0.999).count() > 99000)
  // Seeded, with-replacement variant; seed 0 makes this call deterministic.
  #expect(try await spark.range(100000).sample(true, 0.001, 0).count() < 1000)
  await spark.stop()
}

@Test
func isEmpty() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Expand Down
Loading