
Commit 0af7e4a

[SPARK-51841] Support isLocal and isStreaming for DataFrame
### What changes were proposed in this pull request?

This PR aims to support `isLocal` and `isStreaming` for `DataFrame`.

### Why are the changes needed?

For feature parity. In addition, these APIs are required during other API implementations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #69 from dongjoon-hyun/SPARK-51841.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 1dde04c commit 0af7e4a
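
For context, here is a minimal usage sketch of the two new `DataFrame` APIs, mirroring the test cases added in this commit. The setup (module import, top-level async context, a reachable local Spark Connect server) is an assumption about the test environment, not part of the commit itself:

```swift
// Sketch based on the new tests in DataFrameTests.swift; environment setup is assumed.
import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()

// Catalog-style queries can be answered by the driver alone, so their plans are local.
let showIsLocal = try await spark.sql("SHOW DATABASES").isLocal()      // true

// A range scan needs executors, so it is neither local nor streaming.
let rangeIsLocal = try await spark.range(1).isLocal()                  // false
let rangeIsStreaming = try await spark.range(1).isStreaming()          // false

await spark.stop()
```

Both methods go through the Spark Connect `AnalyzePlan` RPC rather than executing the plan, so they only inspect plan metadata.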

File tree

3 files changed: +68 −0 lines


Sources/SparkConnect/DataFrame.swift

Lines changed: 32 additions & 0 deletions
@@ -390,6 +390,38 @@ public actor DataFrame: Sendable {
     return try await lastN.collect()
   }
 
+  /// Returns true if the `collect` and `take` methods can be run locally
+  /// (without any Spark executors).
+  /// - Returns: True if the plan is local.
+  public func isLocal() async throws -> Bool {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      let response = try await service.analyzePlan(spark.client.getIsLocal(spark.sessionID, plan))
+      return response.isLocal.isLocal
+    }
+  }
+
+  /// Returns true if this `DataFrame` contains one or more sources that continuously return data as it
+  /// arrives.
+  /// - Returns: True if a plan is streaming.
+  public func isStreaming() async throws -> Bool {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
+      return response.isStreaming.isStreaming
+    }
+  }
+
   /// Checks if the ``DataFrame`` is empty and returns a boolean value.
   /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
   public func isEmpty() async throws -> Bool {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 20 additions & 0 deletions
@@ -556,4 +556,24 @@ public actor SparkConnectClient {
     plan.opType = .root(relation)
     return plan
   }
+
+  func getIsLocal(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest {
+    return analyze(
+      sessionID,
+      {
+        var isLocal = AnalyzePlanRequest.IsLocal()
+        isLocal.plan = plan
+        return OneOf_Analyze.isLocal(isLocal)
+      })
+  }
+
+  func getIsStreaming(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest {
+    return analyze(
+      sessionID,
+      {
+        var isStreaming = AnalyzePlanRequest.IsStreaming()
+        isStreaming.plan = plan
+        return OneOf_Analyze.isStreaming(isStreaming)
+      })
+  }
 }

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 16 additions & 0 deletions
@@ -252,6 +252,22 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func isLocal() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("SHOW DATABASES").isLocal())
+    #expect(try await spark.sql("SHOW TABLES").isLocal())
+    #expect(try await spark.range(1).isLocal() == false)
+    await spark.stop()
+  }
+
+  @Test
+  func isStreaming() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(1).isStreaming() == false)
+    await spark.stop()
+  }
+
   #if !os(Linux)
     @Test
     func sort() async throws {
