diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 66652e3..d373a5b 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -390,6 +390,44 @@ public actor DataFrame: Sendable {
     return try await lastN.collect()
   }
 
+  /// Reports whether `collect` and `take` can be evaluated locally, that is,
+  /// without involving any Spark executors.
+  /// - Returns: `true` when the analysis response marks the plan as local.
+  /// - Throws: Rethrows any error from `withGRPCClient` or the `analyzePlan` call.
+  public func isLocal() async throws -> Bool {
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { grpcClient in
+      let stub = Spark_Connect_SparkConnectService.Client(wrapping: grpcClient)
+      let analysis = try await stub.analyzePlan(
+        spark.client.getIsLocal(spark.sessionID, plan)
+      )
+      return analysis.isLocal.isLocal
+    }
+  }
+
+  /// Reports whether this `DataFrame` reads from at least one source that keeps
+  /// returning data as it arrives.
+  /// - Returns: `true` when the analysis response marks the plan as streaming.
+  /// - Throws: Rethrows any error from `withGRPCClient` or the `analyzePlan` call.
+  public func isStreaming() async throws -> Bool {
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { grpcClient in
+      let stub = Spark_Connect_SparkConnectService.Client(wrapping: grpcClient)
+      let analysis = try await stub.analyzePlan(
+        spark.client.getIsStreaming(spark.sessionID, plan)
+      )
+      return analysis.isStreaming.isStreaming
+    }
+  }
+
   /// Checks if the ``DataFrame`` is empty and returns a boolean value.
   /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
   public func isEmpty() async throws -> Bool {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index aa7320f..e058177 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -556,4 +556,24 @@ public actor SparkConnectClient {
     plan.opType = .root(relation)
     return plan
   }
+
+  /// Builds an `AnalyzePlanRequest` carrying an `IsLocal` analysis of `plan`,
+  /// delegating request assembly for `sessionID` to `analyze`.
+  func getIsLocal(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest {
+    analyze(sessionID) {
+      var payload = AnalyzePlanRequest.IsLocal()
+      payload.plan = plan
+      return OneOf_Analyze.isLocal(payload)
+    }
+  }
+
+  /// Builds an `AnalyzePlanRequest` carrying an `IsStreaming` analysis of `plan`,
+  /// delegating request assembly for `sessionID` to `analyze`.
+  func getIsStreaming(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest {
+    analyze(sessionID) {
+      var payload = AnalyzePlanRequest.IsStreaming()
+      payload.plan = plan
+      return OneOf_Analyze.isStreaming(payload)
+    }
+  }
 }
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 07443c3..1839f76 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -252,6 +252,23 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func isLocal() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    for statement in ["SHOW DATABASES", "SHOW TABLES"] {
+      #expect(try await spark.sql(statement).isLocal())
+    }
+    #expect(try await spark.range(1).isLocal() == false)
+    await spark.stop()
+  }
+
+  @Test
+  func isStreaming() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(1).isStreaming() == false)
+    await spark.stop()
+  }
+
   #if !os(Linux)
     @Test
     func sort() async throws {