Skip to content

Commit bfface7

Browse files
committed
[SPARK-51851] Refactor to use withGPRC wrappers
1 parent cb08c76 commit bfface7

File tree

2 files changed

+43
-112
lines changed

2 files changed

+43
-112
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 22 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,24 @@ public actor DataFrame: Sendable {
110110
}
111111
}
112112

113-
private func analyzePlanIfNeeded() async throws {
114-
if self._schema != nil {
115-
return
116-
}
113+
private func withGPRC<Result: Sendable>(
114+
_ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
115+
) async throws -> Result {
117116
try await withGRPCClient(
118117
transport: .http2NIOPosix(
119118
target: .dns(host: spark.client.host, port: spark.client.port),
120119
transportSecurity: .plaintext
121120
)
122121
) { client in
122+
return try await f(client)
123+
}
124+
}
125+
126+
private func analyzePlanIfNeeded() async throws {
127+
if self._schema != nil {
128+
return
129+
}
130+
try await withGPRC { client in
123131
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
124132
let response = try await service.analyzePlan(
125133
spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
@@ -132,12 +140,7 @@ public actor DataFrame: Sendable {
132140
public func count() async throws -> Int64 {
133141
let counter = Atomic(Int64(0))
134142

135-
try await withGRPCClient(
136-
transport: .http2NIOPosix(
137-
target: .dns(host: spark.client.host, port: spark.client.port),
138-
transportSecurity: .plaintext
139-
)
140-
) { client in
143+
try await withGPRC { client in
141144
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
142145
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
143146
response in
@@ -151,12 +154,7 @@ public actor DataFrame: Sendable {
151154

152155
/// Execute the plan and try to fill `schema` and `batches`.
153156
private func execute() async throws {
154-
try await withGRPCClient(
155-
transport: .http2NIOPosix(
156-
target: .dns(host: spark.client.host, port: spark.client.port),
157-
transportSecurity: .plaintext
158-
)
159-
) { client in
157+
try await withGPRC { client in
160158
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
161159
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
162160
response in
@@ -394,12 +392,7 @@ public actor DataFrame: Sendable {
394392
/// (without any Spark executors).
395393
/// - Returns: True if the plan is local.
396394
public func isLocal() async throws -> Bool {
397-
try await withGRPCClient(
398-
transport: .http2NIOPosix(
399-
target: .dns(host: spark.client.host, port: spark.client.port),
400-
transportSecurity: .plaintext
401-
)
402-
) { client in
395+
try await withGPRC { client in
403396
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
404397
let response = try await service.analyzePlan(spark.client.getIsLocal(spark.sessionID, plan))
405398
return response.isLocal.isLocal
@@ -410,12 +403,7 @@ public actor DataFrame: Sendable {
410403
/// arrives.
411404
/// - Returns: True if a plan is streaming.
412405
public func isStreaming() async throws -> Bool {
413-
try await withGRPCClient(
414-
transport: .http2NIOPosix(
415-
target: .dns(host: spark.client.host, port: spark.client.port),
416-
transportSecurity: .plaintext
417-
)
418-
) { client in
406+
try await withGPRC { client in
419407
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
420408
let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
421409
return response.isStreaming.isStreaming
@@ -439,12 +427,7 @@ public actor DataFrame: Sendable {
439427
public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
440428
-> DataFrame
441429
{
442-
try await withGRPCClient(
443-
transport: .http2NIOPosix(
444-
target: .dns(host: spark.client.host, port: spark.client.port),
445-
transportSecurity: .plaintext
446-
)
447-
) { client in
430+
try await withGPRC { client in
448431
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
449432
_ = try await service.analyzePlan(
450433
spark.client.getPersist(spark.sessionID, plan, storageLevel))
@@ -458,12 +441,7 @@ public actor DataFrame: Sendable {
458441
/// - Parameter blocking: Whether to block until all blocks are deleted.
459442
/// - Returns: A `DataFrame`
460443
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
461-
try await withGRPCClient(
462-
transport: .http2NIOPosix(
463-
target: .dns(host: spark.client.host, port: spark.client.port),
464-
transportSecurity: .plaintext
465-
)
466-
) { client in
444+
try await withGPRC { client in
467445
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
468446
_ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
469447
}
@@ -473,12 +451,7 @@ public actor DataFrame: Sendable {
473451

474452
public var storageLevel: StorageLevel {
475453
get async throws {
476-
try await withGRPCClient(
477-
transport: .http2NIOPosix(
478-
target: .dns(host: spark.client.host, port: spark.client.port),
479-
transportSecurity: .plaintext
480-
)
481-
) { client in
454+
try await withGPRC { client in
482455
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
483456
return try await service
484457
.analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
@@ -505,12 +478,7 @@ public actor DataFrame: Sendable {
505478
/// - Parameter mode: the expected output format of plans;
506479
/// `simple`, `extended`, `codegen`, `cost`, `formatted`.
507480
public func explain(_ mode: String) async throws {
508-
try await withGRPCClient(
509-
transport: .http2NIOPosix(
510-
target: .dns(host: spark.client.host, port: spark.client.port),
511-
transportSecurity: .plaintext
512-
)
513-
) { client in
481+
try await withGPRC { client in
514482
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
515483
let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
516484
print(response.explain.explainString)
@@ -522,12 +490,7 @@ public actor DataFrame: Sendable {
522490
/// results. Depending on the source relations, this may not find all input files. Duplicates are removed.
523491
/// - Returns: An array of file path strings.
524492
public func inputFiles() async throws -> [String] {
525-
try await withGRPCClient(
526-
transport: .http2NIOPosix(
527-
target: .dns(host: spark.client.host, port: spark.client.port),
528-
transportSecurity: .plaintext
529-
)
530-
) { client in
493+
try await withGPRC { client in
531494
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
532495
let response = try await service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan))
533496
return response.inputFiles.files
@@ -542,12 +505,7 @@ public actor DataFrame: Sendable {
542505
/// Prints the schema up to the given level to the console in a nice tree format.
543506
/// - Parameter level: A level to be printed.
544507
public func printSchema(_ level: Int32) async throws {
545-
try await withGRPCClient(
546-
transport: .http2NIOPosix(
547-
target: .dns(host: spark.client.host, port: spark.client.port),
548-
transportSecurity: .plaintext
549-
)
550-
) { client in
508+
try await withGPRC { client in
551509
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
552510
let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level))
553511
print(response.treeString.treeString)

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,7 @@ public actor SparkConnectClient {
5252
/// - Parameter sessionID: A string for the session ID.
5353
/// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
5454
func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
55-
try await withGRPCClient(
56-
transport: .http2NIOPosix(
57-
target: .dns(host: self.host, port: self.port),
58-
transportSecurity: .plaintext
59-
)
60-
) { client in
55+
try await withGPRC { client in
6156
// To prevent server-side `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
6257
if UUID(uuidString: sessionID) == nil {
6358
throw SparkConnectError.InvalidSessionIDException
@@ -73,6 +68,19 @@ public actor SparkConnectClient {
7368
}
7469
}
7570

71+
private func withGPRC<Result: Sendable>(
72+
_ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
73+
) async throws -> Result {
74+
try await withGRPCClient(
75+
transport: .http2NIOPosix(
76+
target: .dns(host: self.host, port: self.port),
77+
transportSecurity: .plaintext
78+
)
79+
) { client in
80+
return try await f(client)
81+
}
82+
}
83+
7684
/// Create a ``ConfigRequest`` instance for `Set` operation.
7785
/// - Parameter map: A map of key-value string pairs.
7886
/// - Returns: A ``ConfigRequest`` instance.
@@ -89,12 +97,7 @@ public actor SparkConnectClient {
8997
/// - Parameter map: A map of key-value pairs to set.
9098
/// - Returns: Always return true.
9199
func setConf(map: [String: String]) async throws -> Bool {
92-
try await withGRPCClient(
93-
transport: .http2NIOPosix(
94-
target: .dns(host: self.host, port: self.port),
95-
transportSecurity: .plaintext
96-
)
97-
) { client in
100+
try await withGPRC { client in
98101
let service = SparkConnectService.Client(wrapping: client)
99102
var request = getConfigRequestSet(map: map)
100103
request.clientType = clientType
@@ -118,12 +121,7 @@ public actor SparkConnectClient {
118121
}
119122

120123
func unsetConf(keys: [String]) async throws -> Bool {
121-
try await withGRPCClient(
122-
transport: .http2NIOPosix(
123-
target: .dns(host: self.host, port: self.port),
124-
transportSecurity: .plaintext
125-
)
126-
) { client in
124+
try await withGPRC { client in
127125
let service = SparkConnectService.Client(wrapping: client)
128126
var request = getConfigRequestUnset(keys: keys)
129127
request.clientType = clientType
@@ -150,12 +148,7 @@ public actor SparkConnectClient {
150148
/// - Parameter key: A string for key to look up.
151149
/// - Returns: A string for the value of the key.
152150
func getConf(_ key: String) async throws -> String {
153-
try await withGRPCClient(
154-
transport: .http2NIOPosix(
155-
target: .dns(host: self.host, port: self.port),
156-
transportSecurity: .plaintext
157-
)
158-
) { client in
151+
try await withGPRC { client in
159152
let service = SparkConnectService.Client(wrapping: client)
160153
var request = getConfigRequestGet(keys: [key])
161154
request.clientType = clientType
@@ -179,12 +172,7 @@ public actor SparkConnectClient {
179172
/// Request the server to get all configurations.
180173
/// - Returns: A map of key-value pairs.
181174
func getConfAll() async throws -> [String: String] {
182-
try await withGRPCClient(
183-
transport: .http2NIOPosix(
184-
target: .dns(host: self.host, port: self.port),
185-
transportSecurity: .plaintext
186-
)
187-
) { client in
175+
try await withGPRC { client in
188176
let service = SparkConnectService.Client(wrapping: client)
189177
var request = getConfigRequestGetAll()
190178
request.clientType = clientType
@@ -451,12 +439,7 @@ public actor SparkConnectClient {
451439

452440
func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
453441
self.result.removeAll()
454-
try await withGRPCClient(
455-
transport: .http2NIOPosix(
456-
target: .dns(host: self.host, port: self.port),
457-
transportSecurity: .plaintext
458-
)
459-
) { client in
442+
try await withGPRC { client in
460443
let service = SparkConnectService.Client(wrapping: client)
461444
var plan = Plan()
462445
plan.opType = .command(command)
@@ -501,12 +484,7 @@ public actor SparkConnectClient {
501484
/// - Parameter ddlString: A string to parse.
502485
/// - Returns: A ``Spark_Connect_DataType`` instance.
503486
func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType {
504-
try await withGRPCClient(
505-
transport: .http2NIOPosix(
506-
target: .dns(host: self.host, port: self.port),
507-
transportSecurity: .plaintext
508-
)
509-
) { client in
487+
try await withGPRC { client in
510488
let service = SparkConnectService.Client(wrapping: client)
511489
let request = analyze(self.sessionID!, {
512490
var ddlParse = AnalyzePlanRequest.DDLParse()
@@ -522,12 +500,7 @@ public actor SparkConnectClient {
522500
/// - Parameter jsonString: A JSON string.
523501
/// - Returns: A DDL string.
524502
func jsonToDdl(_ jsonString: String) async throws -> String {
525-
try await withGRPCClient(
526-
transport: .http2NIOPosix(
527-
target: .dns(host: self.host, port: self.port),
528-
transportSecurity: .plaintext
529-
)
530-
) { client in
503+
try await withGPRC { client in
531504
let service = SparkConnectService.Client(wrapping: client)
532505
let request = analyze(self.sessionID!, {
533506
var jsonToDDL = AnalyzePlanRequest.JsonToDDL()

0 commit comments

Comments
 (0)