Skip to content

Commit 1169bea

Browse files
committed
[SPARK-51851] Refactor to use withGPRC wrappers
### What changes were proposed in this pull request? This PR aims to refactor to use `withGPRC` wrappers in `SparkConnectClient` and `DataFrame`. (Note: the helper is named `withGPRC` in the code — presumably a transposition of `withGRPC`; verify the intended spelling.) ### Why are the changes needed? This is helpful not only for reducing the duplicated code but also introducing a security feature. ### Does this PR introduce _any_ user-facing change? No, this is an internal code refactoring. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. This patch had conflicts when merged, resolved by Committer: Dongjoon Hyun <[email protected]> Closes #73 from dongjoon-hyun/SPARK-51851. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 9875380 commit 1169bea

File tree

2 files changed

+43
-112
lines changed

2 files changed

+43
-112
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 22 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,24 @@ public actor DataFrame: Sendable {
110110
}
111111
}
112112

113-
private func analyzePlanIfNeeded() async throws {
114-
if self._schema != nil {
115-
return
116-
}
113+
private func withGPRC<Result: Sendable>(
114+
_ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
115+
) async throws -> Result {
117116
try await withGRPCClient(
118117
transport: .http2NIOPosix(
119118
target: .dns(host: spark.client.host, port: spark.client.port),
120119
transportSecurity: .plaintext
121120
)
122121
) { client in
122+
return try await f(client)
123+
}
124+
}
125+
126+
private func analyzePlanIfNeeded() async throws {
127+
if self._schema != nil {
128+
return
129+
}
130+
try await withGPRC { client in
123131
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
124132
let response = try await service.analyzePlan(
125133
spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
@@ -132,12 +140,7 @@ public actor DataFrame: Sendable {
132140
public func count() async throws -> Int64 {
133141
let counter = Atomic(Int64(0))
134142

135-
try await withGRPCClient(
136-
transport: .http2NIOPosix(
137-
target: .dns(host: spark.client.host, port: spark.client.port),
138-
transportSecurity: .plaintext
139-
)
140-
) { client in
143+
try await withGPRC { client in
141144
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
142145
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
143146
response in
@@ -154,12 +157,7 @@ public actor DataFrame: Sendable {
154157
// Clear all existing batches.
155158
self.batches.removeAll()
156159

157-
try await withGRPCClient(
158-
transport: .http2NIOPosix(
159-
target: .dns(host: spark.client.host, port: spark.client.port),
160-
transportSecurity: .plaintext
161-
)
162-
) { client in
160+
try await withGPRC { client in
163161
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
164162
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
165163
response in
@@ -397,12 +395,7 @@ public actor DataFrame: Sendable {
397395
/// (without any Spark executors).
398396
/// - Returns: True if the plan is local.
399397
public func isLocal() async throws -> Bool {
400-
try await withGRPCClient(
401-
transport: .http2NIOPosix(
402-
target: .dns(host: spark.client.host, port: spark.client.port),
403-
transportSecurity: .plaintext
404-
)
405-
) { client in
398+
try await withGPRC { client in
406399
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
407400
let response = try await service.analyzePlan(spark.client.getIsLocal(spark.sessionID, plan))
408401
return response.isLocal.isLocal
@@ -413,12 +406,7 @@ public actor DataFrame: Sendable {
413406
/// arrives.
414407
/// - Returns: True if a plan is streaming.
415408
public func isStreaming() async throws -> Bool {
416-
try await withGRPCClient(
417-
transport: .http2NIOPosix(
418-
target: .dns(host: spark.client.host, port: spark.client.port),
419-
transportSecurity: .plaintext
420-
)
421-
) { client in
409+
try await withGPRC { client in
422410
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
423411
let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
424412
return response.isStreaming.isStreaming
@@ -442,12 +430,7 @@ public actor DataFrame: Sendable {
442430
public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
443431
-> DataFrame
444432
{
445-
try await withGRPCClient(
446-
transport: .http2NIOPosix(
447-
target: .dns(host: spark.client.host, port: spark.client.port),
448-
transportSecurity: .plaintext
449-
)
450-
) { client in
433+
try await withGPRC { client in
451434
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
452435
_ = try await service.analyzePlan(
453436
spark.client.getPersist(spark.sessionID, plan, storageLevel))
@@ -461,12 +444,7 @@ public actor DataFrame: Sendable {
461444
/// - Parameter blocking: Whether to block until all blocks are deleted.
462445
/// - Returns: A `DataFrame`
463446
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
464-
try await withGRPCClient(
465-
transport: .http2NIOPosix(
466-
target: .dns(host: spark.client.host, port: spark.client.port),
467-
transportSecurity: .plaintext
468-
)
469-
) { client in
447+
try await withGPRC { client in
470448
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
471449
_ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
472450
}
@@ -476,12 +454,7 @@ public actor DataFrame: Sendable {
476454

477455
public var storageLevel: StorageLevel {
478456
get async throws {
479-
try await withGRPCClient(
480-
transport: .http2NIOPosix(
481-
target: .dns(host: spark.client.host, port: spark.client.port),
482-
transportSecurity: .plaintext
483-
)
484-
) { client in
457+
try await withGPRC { client in
485458
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
486459
return try await service
487460
.analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
@@ -508,12 +481,7 @@ public actor DataFrame: Sendable {
508481
/// - Parameter mode: the expected output format of plans;
509482
/// `simple`, `extended`, `codegen`, `cost`, `formatted`.
510483
public func explain(_ mode: String) async throws {
511-
try await withGRPCClient(
512-
transport: .http2NIOPosix(
513-
target: .dns(host: spark.client.host, port: spark.client.port),
514-
transportSecurity: .plaintext
515-
)
516-
) { client in
484+
try await withGPRC { client in
517485
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
518486
let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
519487
print(response.explain.explainString)
@@ -525,12 +493,7 @@ public actor DataFrame: Sendable {
525493
/// results. Depending on the source relations, this may not find all input files. Duplicates are removed.
526494
/// - Returns: An array of file path strings.
527495
public func inputFiles() async throws -> [String] {
528-
try await withGRPCClient(
529-
transport: .http2NIOPosix(
530-
target: .dns(host: spark.client.host, port: spark.client.port),
531-
transportSecurity: .plaintext
532-
)
533-
) { client in
496+
try await withGPRC { client in
534497
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
535498
let response = try await service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan))
536499
return response.inputFiles.files
@@ -545,12 +508,7 @@ public actor DataFrame: Sendable {
545508
/// Prints the schema up to the given level to the console in a nice tree format.
546509
/// - Parameter level: A level to be printed.
547510
public func printSchema(_ level: Int32) async throws {
548-
try await withGRPCClient(
549-
transport: .http2NIOPosix(
550-
target: .dns(host: spark.client.host, port: spark.client.port),
551-
transportSecurity: .plaintext
552-
)
553-
) { client in
511+
try await withGPRC { client in
554512
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
555513
let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level))
556514
print(response.treeString.treeString)

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,7 @@ public actor SparkConnectClient {
5252
/// - Parameter sessionID: A string for the session ID.
5353
/// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
5454
func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
55-
try await withGRPCClient(
56-
transport: .http2NIOPosix(
57-
target: .dns(host: self.host, port: self.port),
58-
transportSecurity: .plaintext
59-
)
60-
) { client in
55+
try await withGPRC { client in
6156
// To prevent server-side `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
6257
if UUID(uuidString: sessionID) == nil {
6358
throw SparkConnectError.InvalidSessionIDException
@@ -73,6 +68,19 @@ public actor SparkConnectClient {
7368
}
7469
}
7570

71+
private func withGPRC<Result: Sendable>(
72+
_ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
73+
) async throws -> Result {
74+
try await withGRPCClient(
75+
transport: .http2NIOPosix(
76+
target: .dns(host: self.host, port: self.port),
77+
transportSecurity: .plaintext
78+
)
79+
) { client in
80+
return try await f(client)
81+
}
82+
}
83+
7684
/// Create a ``ConfigRequest`` instance for `Set` operation.
7785
/// - Parameter map: A map of key-value string pairs.
7886
/// - Returns: A ``ConfigRequest`` instance.
@@ -89,12 +97,7 @@ public actor SparkConnectClient {
8997
/// - Parameter map: A map of key-value pairs to set.
9098
/// - Returns: Always return true.
9199
func setConf(map: [String: String]) async throws -> Bool {
92-
try await withGRPCClient(
93-
transport: .http2NIOPosix(
94-
target: .dns(host: self.host, port: self.port),
95-
transportSecurity: .plaintext
96-
)
97-
) { client in
100+
try await withGPRC { client in
98101
let service = SparkConnectService.Client(wrapping: client)
99102
var request = getConfigRequestSet(map: map)
100103
request.clientType = clientType
@@ -118,12 +121,7 @@ public actor SparkConnectClient {
118121
}
119122

120123
func unsetConf(keys: [String]) async throws -> Bool {
121-
try await withGRPCClient(
122-
transport: .http2NIOPosix(
123-
target: .dns(host: self.host, port: self.port),
124-
transportSecurity: .plaintext
125-
)
126-
) { client in
124+
try await withGPRC { client in
127125
let service = SparkConnectService.Client(wrapping: client)
128126
var request = getConfigRequestUnset(keys: keys)
129127
request.clientType = clientType
@@ -150,12 +148,7 @@ public actor SparkConnectClient {
150148
/// - Parameter key: A string for key to look up.
151149
/// - Returns: A string for the value of the key.
152150
func getConf(_ key: String) async throws -> String {
153-
try await withGRPCClient(
154-
transport: .http2NIOPosix(
155-
target: .dns(host: self.host, port: self.port),
156-
transportSecurity: .plaintext
157-
)
158-
) { client in
151+
try await withGPRC { client in
159152
let service = SparkConnectService.Client(wrapping: client)
160153
var request = getConfigRequestGet(keys: [key])
161154
request.clientType = clientType
@@ -179,12 +172,7 @@ public actor SparkConnectClient {
179172
/// Request the server to get all configurations.
180173
/// - Returns: A map of key-value pairs.
181174
func getConfAll() async throws -> [String: String] {
182-
try await withGRPCClient(
183-
transport: .http2NIOPosix(
184-
target: .dns(host: self.host, port: self.port),
185-
transportSecurity: .plaintext
186-
)
187-
) { client in
175+
try await withGPRC { client in
188176
let service = SparkConnectService.Client(wrapping: client)
189177
var request = getConfigRequestGetAll()
190178
request.clientType = clientType
@@ -451,12 +439,7 @@ public actor SparkConnectClient {
451439

452440
func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
453441
self.result.removeAll()
454-
try await withGRPCClient(
455-
transport: .http2NIOPosix(
456-
target: .dns(host: self.host, port: self.port),
457-
transportSecurity: .plaintext
458-
)
459-
) { client in
442+
try await withGPRC { client in
460443
let service = SparkConnectService.Client(wrapping: client)
461444
var plan = Plan()
462445
plan.opType = .command(command)
@@ -501,12 +484,7 @@ public actor SparkConnectClient {
501484
/// - Parameter ddlString: A string to parse.
502485
/// - Returns: A ``Spark_Connect_DataType`` instance.
503486
func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType {
504-
try await withGRPCClient(
505-
transport: .http2NIOPosix(
506-
target: .dns(host: self.host, port: self.port),
507-
transportSecurity: .plaintext
508-
)
509-
) { client in
487+
try await withGPRC { client in
510488
let service = SparkConnectService.Client(wrapping: client)
511489
let request = analyze(self.sessionID!, {
512490
var ddlParse = AnalyzePlanRequest.DDLParse()
@@ -522,12 +500,7 @@ public actor SparkConnectClient {
522500
/// - Parameter jsonString: A JSON string.
523501
/// - Returns: A DDL string.
524502
func jsonToDdl(_ jsonString: String) async throws -> String {
525-
try await withGRPCClient(
526-
transport: .http2NIOPosix(
527-
target: .dns(host: self.host, port: self.port),
528-
transportSecurity: .plaintext
529-
)
530-
) { client in
503+
try await withGPRC { client in
531504
let service = SparkConnectService.Client(wrapping: client)
532505
let request = analyze(self.sessionID!, {
533506
var jsonToDDL = AnalyzePlanRequest.JsonToDDL()

0 commit comments

Comments (0)