202 changes: 202 additions & 0 deletions Sources/SparkConnect/Catalog.swift
@@ -0,0 +1,202 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

public struct CatalogMetadata: Sendable, Equatable {
public var name: String
public var description: String? = nil
}

public struct Database: Sendable, Equatable {
public var name: String
public var catalog: String? = nil
public var description: String? = nil
public var locationUri: String
}

// TODO: Rename `SparkTable` to `Table` after removing Arrow and Flatbuffer
// from `SparkConnect` module. Currently, `SparkTable` is used to avoid the name conflict.
public struct SparkTable: Sendable, Equatable {
public var name: String
public var catalog: String?
public var namespace: [String]?
public var description: String?
public var tableType: String
public var isTemporary: Bool
public var database: String? {
guard let namespace, namespace.count == 1 else {
return nil
}
return namespace[0]
}
}

public struct Column: Sendable, Equatable {
public var name: String
public var description: String?
public var dataType: String
public var nullable: Bool
public var isPartition: Bool
public var isBucket: Bool
public var isCluster: Bool
}

public struct Function: Sendable, Equatable {
public var name: String
public var catalog: String?
public var namespace: [String]?
public var description: String?
public var className: String
public var isTemporary: Bool
}

/// Interface through which the user may create, drop, alter, or query underlying databases, tables, functions, etc.
/// To access this, use ``SparkSession/catalog``.
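///
/// A minimal usage sketch (assumes a reachable Spark Connect server; the database name is illustrative):
/// ```swift
/// let spark = try await SparkSession.builder.getOrCreate()
/// let catalogs = try await spark.catalog.listCatalogs()
/// let exists = try await spark.catalog.databaseExists("default")
/// ```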
public actor Catalog: Sendable {
var spark: SparkSession

init(spark: SparkSession) {
self.spark = spark
}

/// A helper method to create a `Spark_Connect_Catalog`-based plan.
/// - Parameter f: A closure that creates a `Spark_Connect_Catalog`.
/// - Returns: A ``DataFrame`` containing the result of the given catalog operation.
private func getDataFrame(_ f: () -> Spark_Connect_Catalog) -> DataFrame {
var relation = Relation()
relation.catalog = f()
var plan = Plan()
plan.opType = .root(relation)
return DataFrame(spark: spark, plan: plan)
}

/// Returns the current default catalog in this session.
/// - Returns: A catalog name.
public func currentCatalog() async throws -> String {
let df = getDataFrame({
var catalog = Spark_Connect_Catalog()
catalog.catType = .currentCatalog(Spark_Connect_CurrentCatalog())
return catalog
})
return try await df.collect()[0][0]!
}

/// Sets the current default catalog in this session.
/// - Parameter catalogName: name of the catalog to set
public func setCurrentCatalog(_ catalogName: String) async throws {
let df = getDataFrame({
var setCurrentCatalog = Spark_Connect_SetCurrentCatalog()
setCurrentCatalog.catalogName = catalogName

var catalog = Spark_Connect_Catalog()
catalog.catType = .setCurrentCatalog(setCurrentCatalog)
return catalog
})
_ = try await df.count()
}

/// Returns a list of catalogs in this session.
/// - Parameter pattern: The pattern that the catalog name needs to match.
/// - Returns: A list of ``CatalogMetadata``.
public func listCatalogs(pattern: String? = nil) async throws -> [CatalogMetadata] {
let df = getDataFrame({
var listCatalogs = Spark_Connect_ListCatalogs()
if let pattern {
listCatalogs.pattern = pattern
}
var catalog = Spark_Connect_Catalog()
catalog.catType = .listCatalogs(listCatalogs)
return catalog
})
return try await df.collect().map {
CatalogMetadata(name: $0[0]!, description: $0[1])
}
}

/// Returns the current default database in this session.
/// - Returns: The current default database name.
public func currentDatabase() async throws -> String {
let df = getDataFrame({
var catalog = Spark_Connect_Catalog()
catalog.catType = .currentDatabase(Spark_Connect_CurrentDatabase())
return catalog
})
return try await df.collect()[0][0]!
}

/// Sets the current default database in this session.
/// - Parameter dbName: name of the database to set
public func setCurrentDatabase(_ dbName: String) async throws {
let df = getDataFrame({
var setCurrentDatabase = Spark_Connect_SetCurrentDatabase()
setCurrentDatabase.dbName = dbName

var catalog = Spark_Connect_Catalog()
catalog.catType = .setCurrentDatabase(setCurrentDatabase)
return catalog
})
_ = try await df.count()
}

/// Returns a list of databases available across all sessions.
/// - Parameter pattern: The pattern that the database name needs to match.
/// - Returns: A list of ``Database``.
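///
/// For example (a sketch; assumes `*`-style wildcard patterns, as exercised in `CatalogTests`):
/// ```swift
/// let dbs = try await spark.catalog.listDatabases(pattern: "def*")
/// ```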
public func listDatabases(pattern: String? = nil) async throws -> [Database] {
let df = getDataFrame({
var listDatabases = Spark_Connect_ListDatabases()
if let pattern {
listDatabases.pattern = pattern
}
var catalog = Spark_Connect_Catalog()
catalog.catType = .listDatabases(listDatabases)
return catalog
})
return try await df.collect().map {
Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
}
}

/// Returns the database with the specified name.
/// - Parameter dbName: name of the database to get.
/// - Returns: The ``Database`` matching the given name.
public func getDatabase(_ dbName: String) async throws -> Database {
let df = getDataFrame({
var db = Spark_Connect_GetDatabase()
db.dbName = dbName
var catalog = Spark_Connect_Catalog()
catalog.catType = .getDatabase(db)
return catalog
})
return try await df.collect().map {
Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
}.first!
}

/// Check if the database with the specified name exists.
/// - Parameter dbName: name of the database to check for existence
/// - Returns: A `Bool` indicating whether the database exists.
public func databaseExists(_ dbName: String) async throws -> Bool {
return !(try await self.listDatabases(pattern: dbName).isEmpty)
}
}
7 changes: 7 additions & 0 deletions Sources/SparkConnect/SparkSession.swift
@@ -84,6 +84,13 @@ public actor SparkSession {
}
}

/// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc.
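/// For example (a sketch; assumes an active session `spark`):
/// ```swift
/// let currentDb = try await spark.catalog.currentDatabase()
/// ```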
public var catalog: Catalog {
Catalog(spark: self)
}

/// Stop the current client.
public func stop() async {
await client.stop()
1 change: 1 addition & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -33,6 +33,7 @@ typealias Limit = Spark_Connect_Limit
typealias MapType = Spark_Connect_DataType.Map
typealias NamedTable = Spark_Connect_Read.NamedTable
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
107 changes: 107 additions & 0 deletions Tests/SparkConnectTests/CatalogTests.swift
@@ -0,0 +1,107 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import Testing

@testable import SparkConnect

/// A test suite for `Catalog`
struct CatalogTests {
#if !os(Linux)
@Test
func currentCatalog() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.catalog.currentCatalog() == "spark_catalog")
await spark.stop()
}

@Test
func setCurrentCatalog() async throws {
let spark = try await SparkSession.builder.getOrCreate()
try await spark.catalog.setCurrentCatalog("spark_catalog")
try await #require(throws: Error.self) {
try await spark.catalog.setCurrentCatalog("not_exist_catalog")
}
await spark.stop()
}

@Test
func listCatalogs() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")])
#expect(try await spark.catalog.listCatalogs(pattern: "*") == [CatalogMetadata(name: "spark_catalog")])
#expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0)
await spark.stop()
}

@Test
func currentDatabase() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.catalog.currentDatabase() == "default")
await spark.stop()
}

@Test
func setCurrentDatabase() async throws {
let spark = try await SparkSession.builder.getOrCreate()
try await spark.catalog.setCurrentDatabase("default")
try await #require(throws: Error.self) {
try await spark.catalog.setCurrentDatabase("not_exist_database")
}
await spark.stop()
}

@Test
func listDatabases() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let dbs = try await spark.catalog.listDatabases()
#expect(dbs.count == 1)
#expect(dbs[0].name == "default")
#expect(dbs[0].catalog == "spark_catalog")
#expect(dbs[0].description == "default database")
#expect(dbs[0].locationUri.hasSuffix("spark-warehouse"))
#expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
#expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0)
await spark.stop()
}

@Test
func getDatabase() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let db = try await spark.catalog.getDatabase("default")
#expect(db.name == "default")
#expect(db.catalog == "spark_catalog")
#expect(db.description == "default database")
#expect(db.locationUri.hasSuffix("spark-warehouse"))
try await #require(throws: Error.self) {
try await spark.catalog.getDatabase("not_exist_database")
}
await spark.stop()
}

@Test
func databaseExists() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.catalog.databaseExists("default"))
#expect(try await spark.catalog.databaseExists("not_exist_database") == false)
await spark.stop()
}
#endif
}