Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions Sources/SparkConnect/ErrorUtils.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

/// Utility functions like `org.apache.spark.util.SparkErrorUtils`.
public enum ErrorUtils {
public static func tryWithSafeFinally<T>(
_ block: () async throws -> T, _ finallyBlock: () async throws -> Void
) async rethrows -> T {
let result: T
do {
result = try await block()
try await finallyBlock()
} catch {
try? await finallyBlock()
throw error
}
return result
}
}
9 changes: 8 additions & 1 deletion Tests/SparkConnectTests/CatalogTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ struct CatalogTests {
func databaseExists() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.catalog.databaseExists("default"))
#expect(try await spark.catalog.databaseExists("not_exist_database") == false)

let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
#expect(try await spark.catalog.databaseExists(dbName) == false)
try await SQLHelper.withDatabase(spark, dbName) ({
_ = try await spark.sql("CREATE DATABASE \(dbName)").count()
#expect(try await spark.catalog.databaseExists(dbName))
})
#expect(try await spark.catalog.databaseExists(dbName) == false)
await spark.stop()
}
#endif
Expand Down
9 changes: 5 additions & 4 deletions Tests/SparkConnectTests/DataFrameReaderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ struct DataFrameReaderTests {

@Test
func table() async throws {
let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
#expect(try await spark.read.table(tableName).count() == 2)
#expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
try await SQLHelper.withTable(spark, tableName)({
_ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
#expect(try await spark.read.table(tableName).count() == 2)
})
await spark.stop()
}
}
56 changes: 56 additions & 0 deletions Tests/SparkConnectTests/SQLHelper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import Testing

@testable import SparkConnect

/// A test utility
struct SQLHelper {
public static func withDatabase(_ spark: SparkSession, _ dbNames: String...) -> (
() async throws -> Void
) async throws -> Void {
func body(_ f: () async throws -> Void) async throws {
try await ErrorUtils.tryWithSafeFinally(
f,
{
for name in dbNames {
_ = try await spark.sql("DROP DATABASE IF EXISTS \(name) CASCADE").count()
}
})
}
return body
}

public static func withTable(_ spark: SparkSession, _ tableNames: String...) -> (
() async throws -> Void
) async throws -> Void {
func body(_ f: () async throws -> Void) async throws {
try await ErrorUtils.tryWithSafeFinally(
f,
{
for name in tableNames {
_ = try await spark.sql("DROP TABLE IF EXISTS \(name)").count()
}
})
}
return body
}
}
9 changes: 5 additions & 4 deletions Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ struct SparkSessionTests {

@Test
func table() async throws {
let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
#expect(try await spark.table(tableName).count() == 2)
#expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
try await SQLHelper.withTable(spark, tableName)({
_ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
#expect(try await spark.table(tableName).count() == 2)
})
await spark.stop()
}

Expand Down
Loading