From 2ffc385bc8529f75bfe824d08e9df05df86f067e Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Sun, 6 Apr 2025 19:02:41 +0900
Subject: [PATCH] [SPARK-51730] Add `Catalog` actor and support catalog/database APIs

---
 Sources/SparkConnect/Catalog.swift         | 202 +++++++++++++++++++++
 Sources/SparkConnect/SparkSession.swift    |   7 +
 Sources/SparkConnect/TypeAliases.swift     |   1 +
 Tests/SparkConnectTests/CatalogTests.swift | 107 +++++++++++
 4 files changed, 317 insertions(+)
 create mode 100644 Sources/SparkConnect/Catalog.swift
 create mode 100644 Tests/SparkConnectTests/CatalogTests.swift

diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift
new file mode 100644
index 0000000..ed37bf7
--- /dev/null
+++ b/Sources/SparkConnect/Catalog.swift
@@ -0,0 +1,202 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+import Foundation
+
+public struct CatalogMetadata: Sendable, Equatable {
+  public var name: String
+  public var description: String? = nil
+}
+
+public struct Database: Sendable, Equatable {
+  public var name: String
+  public var catalog: String? = nil
+  public var description: String? = nil
+  public var locationUri: String
+}
+
+// TODO: Rename `SparkTable` to `Table` after removing Arrow and FlatBuffers
+// from the `SparkConnect` module. Currently, `SparkTable` is used to avoid the name conflict.
+public struct SparkTable: Sendable, Equatable {
+  public var name: String
+  public var catalog: String?
+  public var namespace: [String]?
+  public var description: String?
+  public var tableType: String
+  public var isTemporary: Bool
+  public var database: String? {
+    get {
+      guard let namespace else {
+        return nil
+      }
+      if namespace.count == 1 {
+        return namespace[0]
+      } else {
+        return nil
+      }
+    }
+  }
+}
+
+public struct Column: Sendable, Equatable {
+  public var name: String
+  public var description: String?
+  public var dataType: String
+  public var nullable: Bool
+  public var isPartition: Bool
+  public var isBucket: Bool
+  public var isCluster: Bool
+}
+
+public struct Function: Sendable, Equatable {
+  public var name: String
+  public var catalog: String?
+  public var namespace: [String]?
+  public var description: String?
+  public var className: String
+  public var isTemporary: Bool
+}
+
+/// Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
+/// To access this, use ``SparkSession.catalog``.
+public actor Catalog: Sendable {
+  var spark: SparkSession
+
+  init(spark: SparkSession) {
+    self.spark = spark
+  }
+
+  /// A helper method to create a `Spark_Connect_Catalog`-based plan.
+  /// - Parameter f: A closure that creates a `Spark_Connect_Catalog`.
+  /// - Returns: A ``DataFrame`` containing the result of the given catalog operation.
+  private func getDataFrame(_ f: () -> Spark_Connect_Catalog) -> DataFrame {
+    var relation = Relation()
+    relation.catalog = f()
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return DataFrame(spark: spark, plan: plan)
+  }
+
+  /// Returns the current default catalog in this session.
+  /// - Returns: A catalog name.
+  public func currentCatalog() async throws -> String {
+    let df = getDataFrame({
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .currentCatalog(Spark_Connect_CurrentCatalog())
+      return catalog
+    })
+    return try await df.collect()[0][0]!
+  }
+
+  /// Sets the current default catalog in this session.
+  /// - Parameter catalogName: name of the catalog to set
+  public func setCurrentCatalog(_ catalogName: String) async throws {
+    let df = getDataFrame({
+      var setCurrentCatalog = Spark_Connect_SetCurrentCatalog()
+      setCurrentCatalog.catalogName = catalogName
+
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .setCurrentCatalog(setCurrentCatalog)
+      return catalog
+    })
+    _ = try await df.count()
+  }
+
+  /// Returns a list of catalogs in this session.
+  /// - Returns: A list of ``CatalogMetadata``.
+  public func listCatalogs(pattern: String? = nil) async throws -> [CatalogMetadata] {
+    let df = getDataFrame({
+      var listCatalogs = Spark_Connect_ListCatalogs()
+      if let pattern {
+        listCatalogs.pattern = pattern
+      }
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .listCatalogs(listCatalogs)
+      return catalog
+    })
+    return try await df.collect().map {
+      CatalogMetadata(name: $0[0]!, description: $0[1])
+    }
+  }
+
+  /// Returns the current default database in this session.
+  /// - Returns: The current default database name.
+  public func currentDatabase() async throws -> String {
+    let df = getDataFrame({
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .currentDatabase(Spark_Connect_CurrentDatabase())
+      return catalog
+    })
+    return try await df.collect()[0][0]!
+  }
+
+  /// Sets the current default database in this session.
+  /// - Parameter dbName: name of the database to set
+  public func setCurrentDatabase(_ dbName: String) async throws {
+    let df = getDataFrame({
+      var setCurrentDatabase = Spark_Connect_SetCurrentDatabase()
+      setCurrentDatabase.dbName = dbName
+
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .setCurrentDatabase(setCurrentDatabase)
+      return catalog
+    })
+    _ = try await df.count()
+  }
+
+  /// Returns a list of databases available across all sessions.
+  /// - Parameter pattern: The pattern that the database name needs to match.
+  /// - Returns: A list of ``Database``.
+  public func listDatabases(pattern: String? = nil) async throws -> [Database] {
+    let df = getDataFrame({
+      var listDatabases = Spark_Connect_ListDatabases()
+      if let pattern {
+        listDatabases.pattern = pattern
+      }
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .listDatabases(listDatabases)
+      return catalog
+    })
+    return try await df.collect().map {
+      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
+    }
+  }
+
+  /// Get the database with the specified name.
+  /// - Parameter dbName: name of the database to get.
+  /// - Returns: The ``Database`` with the specified name.
+  public func getDatabase(_ dbName: String) async throws -> Database {
+    let df = getDataFrame({
+      var db = Spark_Connect_GetDatabase()
+      db.dbName = dbName
+      var catalog = Spark_Connect_Catalog()
+      catalog.catType = .getDatabase(db)
+      return catalog
+    })
+    return try await df.collect().map {
+      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
+    }.first!
+  }
+
+  /// Check if the database with the specified name exists.
+  /// - Parameter dbName: name of the database to check the existence of.
+  /// - Returns: A boolean indicating whether the database exists.
+  public func databaseExists(_ dbName: String) async throws -> Bool {
+    return try await self.listDatabases(pattern: dbName).count > 0
+  }
+}
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 2fa583c..3b07c27 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -84,6 +84,13 @@ public actor SparkSession {
     }
   }
 
+  /// Interface through which the user may create, drop, alter or query underlying databases, tables, functions, etc.
+  public var catalog: Catalog {
+    get {
+      return Catalog(spark: self)
+    }
+  }
+
   /// Stop the current client.
   public func stop() async {
     await client.stop()
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index f186729..6a700d6 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -33,6 +33,7 @@ typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
 typealias NamedTable = Spark_Connect_Read.NamedTable
 typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
+typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift
new file mode 100644
index 0000000..f49f2db
--- /dev/null
+++ b/Tests/SparkConnectTests/CatalogTests.swift
@@ -0,0 +1,107 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+import Testing
+
+@testable import SparkConnect
+
+/// A test suite for `Catalog`
+struct CatalogTests {
+#if !os(Linux)
+  @Test
+  func currentCatalog() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.catalog.currentCatalog() == "spark_catalog")
+    await spark.stop()
+  }
+
+  @Test
+  func setCurrentCatalog() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.catalog.setCurrentCatalog("spark_catalog")
+    try await #require(throws: Error.self) {
+      try await spark.catalog.setCurrentCatalog("not_exist_catalog")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func listCatalogs() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")])
+    #expect(try await spark.catalog.listCatalogs(pattern: "*") == [CatalogMetadata(name: "spark_catalog")])
+    #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0)
+    await spark.stop()
+  }
+
+  @Test
+  func currentDatabase() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.catalog.currentDatabase() == "default")
+    await spark.stop()
+  }
+
+  @Test
+  func setCurrentDatabase() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.catalog.setCurrentDatabase("default")
+    try await #require(throws: Error.self) {
+      try await spark.catalog.setCurrentDatabase("not_exist_database")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func listDatabases() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let dbs = try await spark.catalog.listDatabases()
+    #expect(dbs.count == 1)
+    #expect(dbs[0].name == "default")
+    #expect(dbs[0].catalog == "spark_catalog")
+    #expect(dbs[0].description == "default database")
+    #expect(dbs[0].locationUri.hasSuffix("spark-warehouse"))
+    #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
+    #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0)
+    await spark.stop()
+  }
+
+  @Test
+  func getDatabase() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let db = try await spark.catalog.getDatabase("default")
+    #expect(db.name == "default")
+    #expect(db.catalog == "spark_catalog")
+    #expect(db.description == "default database")
+    #expect(db.locationUri.hasSuffix("spark-warehouse"))
+    try await #require(throws: Error.self) {
+      try await spark.catalog.getDatabase("not_exist_database")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func databaseExists() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.catalog.databaseExists("default"))
+    #expect(try await spark.catalog.databaseExists("not_exist_database") == false)
+    await spark.stop()
+  }
+#endif
+}
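
---

Usage sketch (not part of the patch): the snippet below exercises the new `Catalog` APIs end to end. It assumes a reachable Spark Connect server; like the tests above, it relies on `SparkSession.builder.getOrCreate()` picking up the connection from the environment (e.g. via `SPARK_REMOTE`), which is an assumption about the local setup.

    import SparkConnect

    // Build a session against a running Spark Connect server.
    let spark = try await SparkSession.builder.getOrCreate()

    // Catalog-level APIs.
    let currentCatalog = try await spark.catalog.currentCatalog()  // e.g. "spark_catalog"
    let catalogs = try await spark.catalog.listCatalogs()
    print(currentCatalog, catalogs.map(\.name))

    // Database-level APIs.
    if try await spark.catalog.databaseExists("default") {
      try await spark.catalog.setCurrentDatabase("default")
    }
    let db = try await spark.catalog.getDatabase("default")
    print(db.name, db.locationUri)

    await spark.stop()

Note that each access to `spark.catalog` creates a fresh `Catalog` actor holding a reference to the session, so the property is cheap and stateless; repeated accesses behave identically.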