Skip to content

Commit 4f8a57e

Browse files
committed
[SPARK-51759] Add ErrorUtils and SQLHelper
1 parent 24150c4 commit 4f8a57e

File tree

5 files changed

+110
-9
lines changed

5 files changed

+110
-9
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
import Foundation
20+
21+
/// Utility functions like `org.apache.spark.util.SparkErrorUtils`.
22+
public enum ErrorUtils {
23+
public static func tryWithSafeFinally<T>(
24+
_ block: () async throws -> T, _ finallyBlock: () async throws -> Void
25+
) async rethrows -> T {
26+
let result: T
27+
do {
28+
result = try await block()
29+
try await finallyBlock()
30+
} catch {
31+
try await finallyBlock()
32+
throw error
33+
}
34+
return result
35+
}
36+
}

Tests/SparkConnectTests/CatalogTests.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,14 @@ struct CatalogTests {
100100
func databaseExists() async throws {
101101
let spark = try await SparkSession.builder.getOrCreate()
102102
#expect(try await spark.catalog.databaseExists("default"))
103-
#expect(try await spark.catalog.databaseExists("not_exist_database") == false)
103+
104+
let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
105+
#expect(try await spark.catalog.databaseExists(dbName) == false)
106+
try await SQLHelper.withDatabase(spark, dbName) ({
107+
_ = try await spark.sql("CREATE DATABASE \(dbName)").count()
108+
#expect(try await spark.catalog.databaseExists(dbName))
109+
})
110+
#expect(try await spark.catalog.databaseExists(dbName) == false)
104111
await spark.stop()
105112
}
106113
#endif

Tests/SparkConnectTests/DataFrameReaderTests.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,12 @@ struct DataFrameReaderTests {
6767

6868
@Test
6969
func table() async throws {
70-
let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
70+
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
7171
let spark = try await SparkSession.builder.getOrCreate()
72-
#expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
73-
#expect(try await spark.read.table(tableName).count() == 2)
74-
#expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
72+
try await SQLHelper.withTable(spark, tableName)({
73+
_ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
74+
#expect(try await spark.read.table(tableName).count() == 2)
75+
})
7576
await spark.stop()
7677
}
7778
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
import Foundation
21+
import Testing
22+
23+
@testable import SparkConnect
24+
25+
/// A test utility
26+
struct SQLHelper {
27+
public static func withDatabase(_ spark: SparkSession, _ dbNames: String...) -> (
28+
() async throws -> Void
29+
) async throws -> Void {
30+
func body(_ f: () async throws -> Void) async throws {
31+
try await ErrorUtils.tryWithSafeFinally(
32+
f,
33+
{
34+
for name in dbNames {
35+
_ = try await spark.sql("DROP DATABASE IF EXISTS \(name) CASCADE").count()
36+
}
37+
})
38+
}
39+
return body
40+
}
41+
42+
public static func withTable(_ spark: SparkSession, _ tableNames: String...) -> (
43+
() async throws -> Void
44+
) async throws -> Void {
45+
func body(_ f: () async throws -> Void) async throws {
46+
try await ErrorUtils.tryWithSafeFinally(
47+
f,
48+
{
49+
for name in tableNames {
50+
_ = try await spark.sql("DROP TABLE IF EXISTS \(name)").count()
51+
}
52+
})
53+
}
54+
return body
55+
}
56+
}

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ struct SparkSessionTests {
7777

7878
@Test
7979
func table() async throws {
80-
let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
80+
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
8181
let spark = try await SparkSession.builder.getOrCreate()
82-
#expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
83-
#expect(try await spark.table(tableName).count() == 2)
84-
#expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
82+
try await SQLHelper.withTable(spark, tableName)({
83+
_ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
84+
#expect(try await spark.table(tableName).count() == 2)
85+
})
8586
await spark.stop()
8687
}
8788

0 commit comments

Comments
 (0)