Skip to content

Commit 9afc485

Browse files
committed
[SPARK-51730] Add Catalog actor and support catalog/database APIs
1 parent 91d6c98 commit 9afc485

File tree

4 files changed

+315
-0
lines changed

4 files changed

+315
-0
lines changed

Sources/SparkConnect/Catalog.swift

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
import Foundation
20+
21+
/// Metadata describing a single catalog, as returned by ``Catalog/listCatalogs(pattern:)``.
public struct CatalogMetadata: Sendable, Equatable {
  /// The catalog name.
  public var name: String
  /// An optional human-readable description of the catalog.
  public var description: String? = nil
}
25+
26+
/// A database in a catalog, as returned by ``Catalog/listDatabases(pattern:)``
/// and ``Catalog/getDatabase(_:)``.
public struct Database: Sendable, Equatable {
  /// The database name.
  public var name: String
  /// The catalog that this database belongs to, if known.
  public var catalog: String? = nil
  /// An optional human-readable description of the database.
  public var description: String? = nil
  /// The location URI (e.g. a warehouse path) of the database.
  public var locationUri: String
}
32+
33+
// TODO: Rename `SparkTable` to `Table` after removing Arrow and Flatbuffer
// from `SparkConnect` module. Currently, `SparkTable` is used to avoid the name conflict.
/// A table in a catalog.
public struct SparkTable: Sendable, Equatable {
  /// The table name.
  public var name: String
  /// The catalog that this table belongs to, if known.
  public var catalog: String?
  /// The multi-part namespace (e.g. `["db"]`) containing this table, if known.
  public var namespace: [String]?
  /// An optional human-readable description of the table.
  public var description: String?
  /// The table type (e.g. managed, external, view) as reported by the server.
  public var tableType: String
  /// Whether this table is a temporary table.
  public var isTemporary: Bool
  /// The single database name, derived from `namespace`.
  /// Returns `nil` when `namespace` is absent or has more than one part.
  public var database: String? {
    guard let namespace, namespace.count == 1 else {
      return nil
    }
    return namespace[0]
  }
}
55+
56+
/// A column of a table or view in a catalog.
public struct Column: Sendable, Equatable {
  /// The column name.
  public var name: String
  /// An optional human-readable description of the column.
  public var description: String?
  /// The column's data type, rendered as a string.
  public var dataType: String
  /// Whether the column may hold NULL values.
  public var nullable: Bool
  /// Whether the column is a partition column.
  public var isPartition: Bool
  /// Whether the column is a bucketing column.
  public var isBucket: Bool
  /// Whether the column is a clustering column.
  public var isCluster: Bool
}
65+
66+
/// A user-defined or built-in function registered in a catalog.
public struct Function: Sendable, Equatable {
  /// The function name.
  public var name: String
  /// The catalog that this function belongs to, if known.
  public var catalog: String?
  /// The multi-part namespace containing this function, if known.
  public var namespace: [String]?
  /// An optional human-readable description of the function.
  public var description: String?
  /// The fully-qualified implementation class name.
  public var className: String
  /// Whether this function is a temporary function.
  public var isTemporary: Bool
}
74+
75+
/// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc.
/// To access this, use ``SparkSession.catalog``.
public actor Catalog: Sendable {
  /// The owning session; every catalog operation is executed through its client.
  var spark: SparkSession

  init(spark: SparkSession) {
    self.spark = spark
  }

  /// A helper method to create a `Spark_Connect_Catalog`-based plan.
  /// - Parameter f: A lambda function to create `Spark_Connect_Catalog`.
  /// - Returns: A ``DataFrame`` contains the result of the given catalog operation.
  private func getDataFrame(_ f: () -> Spark_Connect_Catalog) -> DataFrame {
    var relation = Relation()
    relation.catalog = f()
    var plan = Plan()
    plan.opType = .root(relation)
    return DataFrame(spark: spark, plan: plan)
  }

  /// Returns the current default catalog in this session.
  /// - Returns: A catalog name.
  public func currentCatalog() async throws -> String {
    let df = getDataFrame({
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .currentCatalog(Spark_Connect_CurrentCatalog())
      return catalog
    })
    // The operation yields a single row with a single non-null string column.
    return try await df.collect()[0][0]!
  }

  /// Sets the current default catalog in this session.
  /// - Parameter catalogName: name of the catalog to set
  public func setCurrentCatalog(_ catalogName: String) async throws {
    let df = getDataFrame({
      var setCurrentCatalog = Spark_Connect_SetCurrentCatalog()
      setCurrentCatalog.catalogName = catalogName

      var catalog = Spark_Connect_Catalog()
      catalog.catType = .setCurrentCatalog(setCurrentCatalog)
      return catalog
    })
    // Force execution of the side-effecting plan; the row count itself is unused.
    _ = try await df.count()
  }

  /// Returns a list of catalogs in this session.
  /// - Parameter pattern: An optional pattern that the catalog name needs to match.
  /// - Returns: A list of ``CatalogMetadata``.
  public func listCatalogs(pattern: String? = nil) async throws -> [CatalogMetadata] {
    let df = getDataFrame({
      var listCatalogs = Spark_Connect_ListCatalogs()
      if let pattern {
        listCatalogs.pattern = pattern
      }
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .listCatalogs(listCatalogs)
      return catalog
    })
    // Result schema: (name, description?) per row.
    return try await df.collect().map {
      CatalogMetadata(name: $0[0]!, description: $0[1])
    }
  }

  /// Returns the current default database in this session.
  /// - Returns: The current default database name.
  public func currentDatabase() async throws -> String {
    let df = getDataFrame({
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .currentDatabase(Spark_Connect_CurrentDatabase())
      return catalog
    })
    // The operation yields a single row with a single non-null string column.
    return try await df.collect()[0][0]!
  }

  /// Sets the current default database in this session.
  /// - Parameter dbName: name of the database to set
  public func setCurrentDatabase(_ dbName: String) async throws {
    let df = getDataFrame({
      var setCurrentDatabase = Spark_Connect_SetCurrentDatabase()
      setCurrentDatabase.dbName = dbName

      var catalog = Spark_Connect_Catalog()
      catalog.catType = .setCurrentDatabase(setCurrentDatabase)
      return catalog
    })
    // Force execution of the side-effecting plan; the row count itself is unused.
    _ = try await df.count()
  }

  /// Returns a list of databases available across all sessions.
  /// - Parameter pattern: The pattern that the database name needs to match.
  /// - Returns: A list of ``Database``.
  public func listDatabases(pattern: String? = nil) async throws -> [Database] {
    let df = getDataFrame({
      var listDatabases = Spark_Connect_ListDatabases()
      if let pattern {
        listDatabases.pattern = pattern
      }
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .listDatabases(listDatabases)
      return catalog
    })
    // Result schema: (name, catalog?, description?, locationUri) per row.
    return try await df.collect().map {
      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
    }
  }

  /// Get the database with the specified name.
  /// - Parameter dbName: name of the database to get.
  /// - Returns: The database found by the name.
  public func getDatabase(_ dbName: String) async throws -> Database {
    let df = getDataFrame({
      var db = Spark_Connect_GetDatabase()
      db.dbName = dbName
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .getDatabase(db)
      return catalog
    })
    // If the database does not exist, the server fails the query and
    // `collect()` throws, so on success exactly one row is present.
    return try await df.collect().map {
      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
    }.first!
  }

  /// Check if the database with the specified name exists.
  /// - Parameter dbName: name of the database to check existence
  /// - Returns: Indicating whether the database exists.
  public func databaseExists(_ dbName: String) async throws -> Bool {
    // The name is used as a list pattern; a non-empty result means it exists.
    return !(try await self.listDatabases(pattern: dbName).isEmpty)
  }
}

Sources/SparkConnect/SparkSession.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ public actor SparkSession {
8484
}
8585
}
8686

87+
/// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc.
public var catalog: Catalog {
  // A fresh accessor is created per access; it only carries a reference to this session.
  Catalog(spark: self)
}
93+
8794
/// Stop the current client.
8895
public func stop() async {
8996
await client.stop()

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ typealias Limit = Spark_Connect_Limit
3333
typealias MapType = Spark_Connect_DataType.Map
3434
typealias NamedTable = Spark_Connect_Read.NamedTable
3535
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
36+
typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType
3637
typealias Plan = Spark_Connect_Plan
3738
typealias Project = Spark_Connect_Project
3839
typealias Range = Spark_Connect_Range
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
import Foundation
21+
import Testing
22+
23+
@testable import SparkConnect
24+
25+
/// A test suite for `Catalog`
struct CatalogTests {
  /// The current catalog of a fresh session is the built-in `spark_catalog`.
  @Test
  func currentCatalog() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.catalog.currentCatalog() == "spark_catalog")
    await spark.stop()
  }

  /// Setting an existing catalog succeeds; a nonexistent one throws.
  @Test
  func setCurrentCatalog() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    try await spark.catalog.setCurrentCatalog("spark_catalog")
    try await #require(throws: Error.self) {
      try await spark.catalog.setCurrentCatalog("not_exist_catalog")
    }
    await spark.stop()
  }

  /// Listing catalogs returns only `spark_catalog`, with and without a pattern.
  @Test
  func listCatalogs() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let expected = [CatalogMetadata(name: "spark_catalog")]
    #expect(try await spark.catalog.listCatalogs() == expected)
    #expect(try await spark.catalog.listCatalogs(pattern: "*") == expected)
    #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0)
    await spark.stop()
  }

  /// The current database of a fresh session is `default`.
  @Test
  func currentDatabase() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.catalog.currentDatabase() == "default")
    await spark.stop()
  }

  /// Setting an existing database succeeds; a nonexistent one throws.
  @Test
  func setCurrentDatabase() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    try await spark.catalog.setCurrentDatabase("default")
    try await #require(throws: Error.self) {
      try await spark.catalog.setCurrentDatabase("not_exist_database")
    }
    await spark.stop()
  }

  /// Listing databases returns the single `default` database with its metadata.
  @Test
  func listDatabases() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let dbs = try await spark.catalog.listDatabases()
    #expect(dbs.count == 1)
    let db = dbs[0]
    #expect(db.name == "default")
    #expect(db.catalog == "spark_catalog")
    #expect(db.description == "default database")
    #expect(db.locationUri.hasSuffix("spark-warehouse"))
    #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
    #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0)
    await spark.stop()
  }

  /// Getting `default` returns its metadata; a nonexistent name throws.
  @Test
  func getDatabase() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let db = try await spark.catalog.getDatabase("default")
    #expect(db.name == "default")
    #expect(db.catalog == "spark_catalog")
    #expect(db.description == "default database")
    #expect(db.locationUri.hasSuffix("spark-warehouse"))
    try await #require(throws: Error.self) {
      try await spark.catalog.getDatabase("not_exist_database")
    }
    await spark.stop()
  }

  /// Existence checks for a present and an absent database.
  @Test
  func databaseExists() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.catalog.databaseExists("default"))
    #expect(try await spark.catalog.databaseExists("not_exist_database") == false)
    await spark.stop()
  }
}

0 commit comments

Comments
 (0)