Skip to content

Commit d14b00c

Browse files
committed
[SPARK-52302] Improve stop to use ReleaseSessionRequest
1 parent 8600b46 commit d14b00c

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,17 @@ public actor SparkConnectClient {
8282
self.userContext = userName.toUserContext
8383
}
8484

85-
/// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
86-
func stop() {
85+
/// Stop the connection.
86+
func stop() async {
87+
guard self.sessionID != nil else { return }
88+
try? await withGPRC { client in
89+
let service = SparkConnectService.Client(wrapping: client)
90+
var request = Spark_Connect_ReleaseSessionRequest()
91+
request.sessionID = self.sessionID!
92+
request.userContext = self.userContext
93+
request.clientType = self.clientType
94+
_ = try await service.releaseSession(request)
95+
}
8796
}
8897

8998
/// Connect to the `Spark Connect` server with the given session ID string.

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import Testing
2727
struct SparkSessionTests {
2828
@Test
2929
func sparkContext() async throws {
30+
await SparkSession.builder.clear()
3031
let spark = try await SparkSession.builder.getOrCreate()
3132
await #expect(throws: SparkConnectError.UnsupportedOperationException) {
3233
try await spark.sparkContext
@@ -36,12 +37,14 @@ struct SparkSessionTests {
3637

3738
@Test
3839
func stop() async throws {
40+
await SparkSession.builder.clear()
3941
let spark = try await SparkSession.builder.getOrCreate()
4042
await spark.stop()
4143
}
4244

4345
@Test
4446
func newSession() async throws {
47+
await SparkSession.builder.clear()
4548
let spark = try await SparkSession.builder.getOrCreate()
4649
await spark.stop()
4750
let newSpark = try await spark.newSession()
@@ -52,16 +55,29 @@ struct SparkSessionTests {
5255

5356
@Test
5457
func sessionID() async throws {
58+
await SparkSession.builder.clear()
5559
let spark1 = try await SparkSession.builder.getOrCreate()
56-
await spark1.stop()
5760
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
5861
let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate()
5962
await spark2.stop()
6063
#expect(spark1.sessionID == spark2.sessionID)
6164
#expect(spark1 == spark2)
6265
}
6366

67+
@Test
68+
func closedSessionID() async throws {
69+
await SparkSession.builder.clear()
70+
let spark1 = try await SparkSession.builder.getOrCreate()
71+
let sessionID = spark1.sessionID
72+
await spark1.stop()
73+
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
74+
try await #require(throws: Error.self) {
75+
try await SparkSession.builder.remote("\(remote)/;session_id=\(sessionID)").getOrCreate()
76+
}
77+
}
78+
6479
@Test func userContext() async throws {
80+
await SparkSession.builder.clear()
6581
let spark = try await SparkSession.builder.getOrCreate()
6682
#if os(macOS) || os(Linux)
6783
let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
@@ -74,6 +90,7 @@ struct SparkSessionTests {
7490

7591
@Test
7692
func version() async throws {
93+
await SparkSession.builder.clear()
7794
let spark = try await SparkSession.builder.getOrCreate()
7895
let version = await spark.version
7996
#expect(version.starts(with: "4.0.0") || version.starts(with: "3.5."))
@@ -82,6 +99,7 @@ struct SparkSessionTests {
8299

83100
@Test
84101
func conf() async throws {
102+
await SparkSession.builder.clear()
85103
let spark = try await SparkSession.builder.getOrCreate()
86104
try await spark.conf.set("spark.x", "y")
87105
#expect(try await spark.conf.get("spark.x") == "y")
@@ -91,6 +109,7 @@ struct SparkSessionTests {
91109

92110
@Test
93111
func emptyDataFrame() async throws {
112+
await SparkSession.builder.clear()
94113
let spark = try await SparkSession.builder.getOrCreate()
95114
#expect(try await spark.emptyDataFrame.count() == 0)
96115
#expect(try await spark.emptyDataFrame.dtypes.isEmpty)
@@ -100,6 +119,7 @@ struct SparkSessionTests {
100119

101120
@Test
102121
func range() async throws {
122+
await SparkSession.builder.clear()
103123
let spark = try await SparkSession.builder.getOrCreate()
104124
#expect(try await spark.range(10).count() == 10)
105125
#expect(try await spark.range(0, 100).count() == 100)
@@ -110,6 +130,7 @@ struct SparkSessionTests {
110130
#if !os(Linux)
111131
@Test
112132
func sql() async throws {
133+
await SparkSession.builder.clear()
113134
let spark = try await SparkSession.builder.getOrCreate()
114135
let expected = [Row(true, 1, "a")]
115136
if await spark.version.starts(with: "4.") {
@@ -122,6 +143,7 @@ struct SparkSessionTests {
122143

123144
@Test
124145
func table() async throws {
146+
await SparkSession.builder.clear()
125147
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
126148
let spark = try await SparkSession.builder.getOrCreate()
127149
try await SQLHelper.withTable(spark, tableName)({
@@ -133,6 +155,7 @@ struct SparkSessionTests {
133155

134156
@Test
135157
func time() async throws {
158+
await SparkSession.builder.clear()
136159
let spark = try await SparkSession.builder.getOrCreate()
137160
#expect(try await spark.time(spark.range(1000).count) == 1000)
138161
#if !os(Linux)
@@ -144,6 +167,7 @@ struct SparkSessionTests {
144167

145168
@Test
146169
func tag() async throws {
170+
await SparkSession.builder.clear()
147171
let spark = try await SparkSession.builder.getOrCreate()
148172
try await spark.addTag("tag1")
149173
#expect(await spark.getTags() == Set(["tag1"]))
@@ -158,6 +182,7 @@ struct SparkSessionTests {
158182

159183
@Test
160184
func invalidTags() async throws {
185+
await SparkSession.builder.clear()
161186
let spark = try await SparkSession.builder.getOrCreate()
162187
await #expect(throws: SparkConnectError.InvalidArgumentException) {
163188
try await spark.addTag("")
@@ -170,20 +195,23 @@ struct SparkSessionTests {
170195

171196
@Test
172197
func interruptAll() async throws {
198+
await SparkSession.builder.clear()
173199
let spark = try await SparkSession.builder.getOrCreate()
174200
#expect(try await spark.interruptAll() == [])
175201
await spark.stop()
176202
}
177203

178204
@Test
179205
func interruptTag() async throws {
206+
await SparkSession.builder.clear()
180207
let spark = try await SparkSession.builder.getOrCreate()
181208
#expect(try await spark.interruptTag("etl") == [])
182209
await spark.stop()
183210
}
184211

185212
@Test
186213
func interruptOperation() async throws {
214+
await SparkSession.builder.clear()
187215
let spark = try await SparkSession.builder.getOrCreate()
188216
#expect(try await spark.interruptOperation("id") == [])
189217
await spark.stop()

0 commit comments

Comments
 (0)