Skip to content

Commit 0dd07eb

Browse files
committed
[SPARK-52302] Improve stop to use ReleaseSessionRequest
### What changes were proposed in this pull request? This PR aims to improve `SparkSession.stop` to use `ReleaseSessionRequest`. ### Why are the changes needed? For feature parity: closing the session immediately when `spark.stop()` is called. ### Does this PR introduce _any_ user-facing change? Yes, but new behavior is consistent with other Spark Connect clients. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No Closes #175 from dongjoon-hyun/SPARK-52302. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent cfa33b8 commit 0dd07eb

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,17 @@ public actor SparkConnectClient {
8282
self.userContext = userName.toUserContext
8383
}
8484

85-
/// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
86-
func stop() {
85+
/// Stop the connection.
86+
func stop() async {
87+
guard self.sessionID != nil else { return }
88+
try? await withGPRC { client in
89+
let service = SparkConnectService.Client(wrapping: client)
90+
var request = Spark_Connect_ReleaseSessionRequest()
91+
request.sessionID = self.sessionID!
92+
request.userContext = self.userContext
93+
request.clientType = self.clientType
94+
_ = try await service.releaseSession(request)
95+
}
8796
}
8897

8998
/// Connect to the `Spark Connect` server with the given session ID string.

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import Testing
2727
struct SparkSessionTests {
2828
@Test
2929
func sparkContext() async throws {
30+
await SparkSession.builder.clear()
3031
let spark = try await SparkSession.builder.getOrCreate()
3132
await #expect(throws: SparkConnectError.UnsupportedOperationException) {
3233
try await spark.sparkContext
@@ -36,12 +37,14 @@ struct SparkSessionTests {
3637

3738
@Test
3839
func stop() async throws {
40+
await SparkSession.builder.clear()
3941
let spark = try await SparkSession.builder.getOrCreate()
4042
await spark.stop()
4143
}
4244

4345
@Test
4446
func newSession() async throws {
47+
await SparkSession.builder.clear()
4548
let spark = try await SparkSession.builder.getOrCreate()
4649
await spark.stop()
4750
let newSpark = try await spark.newSession()
@@ -52,16 +55,31 @@ struct SparkSessionTests {
5255

5356
@Test
5457
func sessionID() async throws {
58+
await SparkSession.builder.clear()
5559
let spark1 = try await SparkSession.builder.getOrCreate()
56-
await spark1.stop()
5760
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
5861
let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate()
5962
await spark2.stop()
6063
#expect(spark1.sessionID == spark2.sessionID)
6164
#expect(spark1 == spark2)
6265
}
6366

67+
@Test
68+
func closedSessionID() async throws {
69+
await SparkSession.builder.clear()
70+
let spark1 = try await SparkSession.builder.getOrCreate()
71+
if await spark1.version >= "4.0.0" {
72+
let sessionID = spark1.sessionID
73+
await spark1.stop()
74+
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
75+
try await #require(throws: Error.self) {
76+
try await SparkSession.builder.remote("\(remote)/;session_id=\(sessionID)").getOrCreate()
77+
}
78+
}
79+
}
80+
6481
@Test func userContext() async throws {
82+
await SparkSession.builder.clear()
6583
let spark = try await SparkSession.builder.getOrCreate()
6684
#if os(macOS) || os(Linux)
6785
let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
@@ -74,6 +92,7 @@ struct SparkSessionTests {
7492

7593
@Test
7694
func version() async throws {
95+
await SparkSession.builder.clear()
7796
let spark = try await SparkSession.builder.getOrCreate()
7897
let version = await spark.version
7998
#expect(version.starts(with: "4.0.0") || version.starts(with: "3.5."))
@@ -82,6 +101,7 @@ struct SparkSessionTests {
82101

83102
@Test
84103
func conf() async throws {
104+
await SparkSession.builder.clear()
85105
let spark = try await SparkSession.builder.getOrCreate()
86106
try await spark.conf.set("spark.x", "y")
87107
#expect(try await spark.conf.get("spark.x") == "y")
@@ -91,6 +111,7 @@ struct SparkSessionTests {
91111

92112
@Test
93113
func emptyDataFrame() async throws {
114+
await SparkSession.builder.clear()
94115
let spark = try await SparkSession.builder.getOrCreate()
95116
#expect(try await spark.emptyDataFrame.count() == 0)
96117
#expect(try await spark.emptyDataFrame.dtypes.isEmpty)
@@ -100,6 +121,7 @@ struct SparkSessionTests {
100121

101122
@Test
102123
func range() async throws {
124+
await SparkSession.builder.clear()
103125
let spark = try await SparkSession.builder.getOrCreate()
104126
#expect(try await spark.range(10).count() == 10)
105127
#expect(try await spark.range(0, 100).count() == 100)
@@ -110,6 +132,7 @@ struct SparkSessionTests {
110132
#if !os(Linux)
111133
@Test
112134
func sql() async throws {
135+
await SparkSession.builder.clear()
113136
let spark = try await SparkSession.builder.getOrCreate()
114137
let expected = [Row(true, 1, "a")]
115138
if await spark.version.starts(with: "4.") {
@@ -122,6 +145,7 @@ struct SparkSessionTests {
122145

123146
@Test
124147
func table() async throws {
148+
await SparkSession.builder.clear()
125149
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
126150
let spark = try await SparkSession.builder.getOrCreate()
127151
try await SQLHelper.withTable(spark, tableName)({
@@ -133,6 +157,7 @@ struct SparkSessionTests {
133157

134158
@Test
135159
func time() async throws {
160+
await SparkSession.builder.clear()
136161
let spark = try await SparkSession.builder.getOrCreate()
137162
#expect(try await spark.time(spark.range(1000).count) == 1000)
138163
#if !os(Linux)
@@ -144,6 +169,7 @@ struct SparkSessionTests {
144169

145170
@Test
146171
func tag() async throws {
172+
await SparkSession.builder.clear()
147173
let spark = try await SparkSession.builder.getOrCreate()
148174
try await spark.addTag("tag1")
149175
#expect(await spark.getTags() == Set(["tag1"]))
@@ -158,6 +184,7 @@ struct SparkSessionTests {
158184

159185
@Test
160186
func invalidTags() async throws {
187+
await SparkSession.builder.clear()
161188
let spark = try await SparkSession.builder.getOrCreate()
162189
await #expect(throws: SparkConnectError.InvalidArgumentException) {
163190
try await spark.addTag("")
@@ -170,20 +197,23 @@ struct SparkSessionTests {
170197

171198
@Test
172199
func interruptAll() async throws {
200+
await SparkSession.builder.clear()
173201
let spark = try await SparkSession.builder.getOrCreate()
174202
#expect(try await spark.interruptAll() == [])
175203
await spark.stop()
176204
}
177205

178206
@Test
179207
func interruptTag() async throws {
208+
await SparkSession.builder.clear()
180209
let spark = try await SparkSession.builder.getOrCreate()
181210
#expect(try await spark.interruptTag("etl") == [])
182211
await spark.stop()
183212
}
184213

185214
@Test
186215
func interruptOperation() async throws {
216+
await SparkSession.builder.clear()
187217
let spark = try await SparkSession.builder.getOrCreate()
188218
#expect(try await spark.interruptOperation("id") == [])
189219
await spark.stop()

0 commit comments

Comments
 (0)