Commit ca07010
[SPARK-51730] Add Catalog actor and support catalog/database APIs
### What changes were proposed in this pull request?

This PR aims to add a `Catalog` actor and support the `catalog/database` APIs. The other APIs (`table/function/column`) will be added independently as a second part.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No. This is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs and manual testing in a macOS environment.

```
$ swift test --filter CatalogTests
...
􀟈 Test databaseExists() started.
􀟈 Test setCurrentCatalog() started.
􀟈 Test setCurrentDatabase() started.
􀟈 Test getDatabase() started.
􀟈 Test listDatabases() started.
􀟈 Test listCatalogs() started.
􀟈 Test currentDatabase() started.
􀟈 Test currentCatalog() started.
􁁛 Test currentDatabase() passed after 0.061 seconds.
􁁛 Test currentCatalog() passed after 0.061 seconds.
􁁛 Test setCurrentCatalog() passed after 0.067 seconds.
􁁛 Test setCurrentDatabase() passed after 0.071 seconds.
􁁛 Test getDatabase() passed after 0.138 seconds.
􁁛 Test listCatalogs() passed after 0.143 seconds.
􁁛 Test databaseExists() passed after 0.158 seconds.
􁁛 Test listDatabases() passed after 0.188 seconds.
􁁛 Suite CatalogTests passed after 0.189 seconds.
􁁛 Test run with 8 tests passed after 0.189 seconds.
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44 from dongjoon-hyun/SPARK-51730.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
Parent: 91d6c98

File tree: 4 files changed, +317 −0 lines

Sources/SparkConnect/Catalog.swift

Lines changed: 202 additions & 0 deletions (new file)

```swift
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

public struct CatalogMetadata: Sendable, Equatable {
  public var name: String
  public var description: String? = nil
}

public struct Database: Sendable, Equatable {
  public var name: String
  public var catalog: String? = nil
  public var description: String? = nil
  public var locationUri: String
}

// TODO: Rename `SparkTable` to `Table` after removing Arrow and Flatbuffer
// from the `SparkConnect` module. Currently, `SparkTable` is used to avoid the name conflict.
public struct SparkTable: Sendable, Equatable {
  public var name: String
  public var catalog: String?
  public var namespace: [String]?
  public var description: String?
  public var tableType: String
  public var isTemporary: Bool
  public var database: String? {
    get {
      guard let namespace else {
        return nil
      }
      if namespace.count == 1 {
        return namespace[0]
      } else {
        return nil
      }
    }
  }
}

public struct Column: Sendable, Equatable {
  public var name: String
  public var description: String?
  public var dataType: String
  public var nullable: Bool
  public var isPartition: Bool
  public var isBucket: Bool
  public var isCluster: Bool
}

public struct Function: Sendable, Equatable {
  public var name: String
  public var catalog: String?
  public var namespace: [String]?
  public var description: String?
  public var className: String
  public var isTemporary: Bool
}

/// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc.
/// To access this, use ``SparkSession.catalog``.
public actor Catalog: Sendable {
  var spark: SparkSession

  init(spark: SparkSession) {
    self.spark = spark
  }

  /// A helper method to create a `Spark_Connect_Catalog`-based plan.
  /// - Parameter f: A lambda function to create `Spark_Connect_Catalog`.
  /// - Returns: A ``DataFrame`` containing the result of the given catalog operation.
  private func getDataFrame(_ f: () -> Spark_Connect_Catalog) -> DataFrame {
    var relation = Relation()
    relation.catalog = f()
    var plan = Plan()
    plan.opType = .root(relation)
    return DataFrame(spark: spark, plan: plan)
  }

  /// Returns the current default catalog in this session.
  /// - Returns: A catalog name.
  public func currentCatalog() async throws -> String {
    let df = getDataFrame({
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .currentCatalog(Spark_Connect_CurrentCatalog())
      return catalog
    })
    return try await df.collect()[0][0]!
  }

  /// Sets the current default catalog in this session.
  /// - Parameter catalogName: name of the catalog to set
  public func setCurrentCatalog(_ catalogName: String) async throws {
    let df = getDataFrame({
      var setCurrentCatalog = Spark_Connect_SetCurrentCatalog()
      setCurrentCatalog.catalogName = catalogName

      var catalog = Spark_Connect_Catalog()
      catalog.catType = .setCurrentCatalog(setCurrentCatalog)
      return catalog
    })
    _ = try await df.count()
  }

  /// Returns a list of catalogs in this session.
  /// - Parameter pattern: The pattern that the catalog name needs to match.
  /// - Returns: A list of ``CatalogMetadata``.
  public func listCatalogs(pattern: String? = nil) async throws -> [CatalogMetadata] {
    let df = getDataFrame({
      var listCatalogs = Spark_Connect_ListCatalogs()
      if let pattern {
        listCatalogs.pattern = pattern
      }
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .listCatalogs(listCatalogs)
      return catalog
    })
    return try await df.collect().map {
      CatalogMetadata(name: $0[0]!, description: $0[1])
    }
  }

  /// Returns the current default database in this session.
  /// - Returns: The current default database name.
  public func currentDatabase() async throws -> String {
    let df = getDataFrame({
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .currentDatabase(Spark_Connect_CurrentDatabase())
      return catalog
    })
    return try await df.collect()[0][0]!
  }

  /// Sets the current default database in this session.
  /// - Parameter dbName: name of the database to set
  public func setCurrentDatabase(_ dbName: String) async throws {
    let df = getDataFrame({
      var setCurrentDatabase = Spark_Connect_SetCurrentDatabase()
      setCurrentDatabase.dbName = dbName

      var catalog = Spark_Connect_Catalog()
      catalog.catType = .setCurrentDatabase(setCurrentDatabase)
      return catalog
    })
    _ = try await df.count()
  }

  /// Returns a list of databases available across all sessions.
  /// - Parameter pattern: The pattern that the database name needs to match.
  /// - Returns: A list of ``Database``.
  public func listDatabases(pattern: String? = nil) async throws -> [Database] {
    let df = getDataFrame({
      var listDatabases = Spark_Connect_ListDatabases()
      if let pattern {
        listDatabases.pattern = pattern
      }
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .listDatabases(listDatabases)
      return catalog
    })
    return try await df.collect().map {
      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
    }
  }

  /// Get the database with the specified name.
  /// - Parameter dbName: name of the database to get.
  /// - Returns: The database found by the name.
  public func getDatabase(_ dbName: String) async throws -> Database {
    let df = getDataFrame({
      var db = Spark_Connect_GetDatabase()
      db.dbName = dbName
      var catalog = Spark_Connect_Catalog()
      catalog.catType = .getDatabase(db)
      return catalog
    })
    return try await df.collect().map {
      Database(name: $0[0]!, catalog: $0[1], description: $0[2], locationUri: $0[3]!)
    }.first!
  }

  /// Check if the database with the specified name exists.
  /// - Parameter dbName: name of the database to check existence
  /// - Returns: Indicating whether the database exists.
  public func databaseExists(_ dbName: String) async throws -> Bool {
    return try await self.listDatabases(pattern: dbName).count > 0
  }
}
```
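Not part of this commit: for orientation, a minimal usage sketch of the new APIs. It assumes a running Spark Connect server and an async context (for example, a `@main` async entry point); the session is obtained the same way the test suite below obtains it.

```swift
import SparkConnect

// Hypothetical client code, not from this commit; assumes a running
// Spark Connect server and that this runs inside an async context.
let spark = try await SparkSession.builder.getOrCreate()

// Catalog-level APIs.
let current = try await spark.catalog.currentCatalog()  // e.g. "spark_catalog"
try await spark.catalog.setCurrentCatalog(current)
let catalogs = try await spark.catalog.listCatalogs(pattern: "*")
print(catalogs.map { $0.name })

// Database-level APIs.
if try await spark.catalog.databaseExists("default") {
  try await spark.catalog.setCurrentDatabase("default")
  let db = try await spark.catalog.getDatabase("default")
  print("\(db.name): \(db.locationUri)")
}

await spark.stop()
```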

Sources/SparkConnect/SparkSession.swift

Lines changed: 7 additions & 0 deletions

```diff
@@ -84,6 +84,13 @@ public actor SparkSession {
     }
   }
 
+  /// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc.
+  public var catalog: Catalog {
+    get {
+      return Catalog(spark: self)
+    }
+  }
+
   /// Stop the current client.
   public func stop() async {
     await client.stop()
```
Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions

```diff
@@ -33,6 +33,7 @@ typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
 typealias NamedTable = Spark_Connect_Read.NamedTable
 typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
+typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
```
CatalogTests.swift

Lines changed: 107 additions & 0 deletions (new file)

```swift
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import Testing

@testable import SparkConnect

/// A test suite for `Catalog`
struct CatalogTests {
  #if !os(Linux)
    @Test
    func currentCatalog() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      #expect(try await spark.catalog.currentCatalog() == "spark_catalog")
      await spark.stop()
    }

    @Test
    func setCurrentCatalog() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      try await spark.catalog.setCurrentCatalog("spark_catalog")
      try await #require(throws: Error.self) {
        try await spark.catalog.setCurrentCatalog("not_exist_catalog")
      }
      await spark.stop()
    }

    @Test
    func listCatalogs() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")])
      #expect(try await spark.catalog.listCatalogs(pattern: "*") == [CatalogMetadata(name: "spark_catalog")])
      #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0)
      await spark.stop()
    }

    @Test
    func currentDatabase() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      #expect(try await spark.catalog.currentDatabase() == "default")
      await spark.stop()
    }

    @Test
    func setCurrentDatabase() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      try await spark.catalog.setCurrentDatabase("default")
      try await #require(throws: Error.self) {
        try await spark.catalog.setCurrentDatabase("not_exist_database")
      }
      await spark.stop()
    }

    @Test
    func listDatabases() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      let dbs = try await spark.catalog.listDatabases()
      #expect(dbs.count == 1)
      #expect(dbs[0].name == "default")
      #expect(dbs[0].catalog == "spark_catalog")
      #expect(dbs[0].description == "default database")
      #expect(dbs[0].locationUri.hasSuffix("spark-warehouse"))
      #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
      #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0)
      await spark.stop()
    }

    @Test
    func getDatabase() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      let db = try await spark.catalog.getDatabase("default")
      #expect(db.name == "default")
      #expect(db.catalog == "spark_catalog")
      #expect(db.description == "default database")
      #expect(db.locationUri.hasSuffix("spark-warehouse"))
      try await #require(throws: Error.self) {
        try await spark.catalog.getDatabase("not_exist_database")
      }
      await spark.stop()
    }

    @Test
    func databaseExists() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      #expect(try await spark.catalog.databaseExists("default"))
      #expect(try await spark.catalog.databaseExists("not_exist_database") == false)
      await spark.stop()
    }
  #endif
}
```
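Every test opens a session with `SparkSession.builder.getOrCreate()`, so the suite requires a live Spark Connect server; that is presumably why it is fenced with `#if !os(Linux)` for CI. A sketch of a local run, under the assumption that the builder resolves the server address from the standard `SPARK_REMOTE` variable used by other Spark Connect clients:

```
# Assumptions: a Spark Connect server is already running locally
# (for example, started via Apache Spark's sbin/start-connect-server.sh),
# and the builder honors SPARK_REMOTE like other Spark Connect clients.
$ export SPARK_REMOTE="sc://localhost:15002"
$ swift test --filter CatalogTests
```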
