Commit cee0edf

[SPARK-51968] Support (cache|uncache|refresh)Table, refreshByPath, isCached, clearCache in Catalog
### What changes were proposed in this pull request?

This PR aims to support the following APIs of `Catalog`.
- `cacheTable`
- `uncacheTable`
- `refreshTable`
- `refreshByPath`
- `isCached`
- `clearCache`

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #96 from dongjoon-hyun/SPARK-51968.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d26e32b commit cee0edf

3 files changed: +191 −0 lines changed


Sources/SparkConnect/Catalog.swift

Lines changed: 86 additions & 0 deletions
@@ -199,4 +199,90 @@ public actor Catalog: Sendable {
   public func databaseExists(_ dbName: String) async throws -> Bool {
     return try await self.listDatabases(pattern: dbName).count > 0
   }
+
+  /// Caches the specified table in-memory.
+  /// - Parameters:
+  ///   - tableName: A qualified or unqualified name that designates a table/view.
+  ///     If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  ///   - storageLevel: storage level to cache table.
+  public func cacheTable(_ tableName: String, _ storageLevel: StorageLevel? = nil) async throws {
+    let df = getDataFrame({
+      var cacheTable = Spark_Connect_CacheTable()
+      cacheTable.tableName = tableName
+      if let storageLevel {
+        cacheTable.storageLevel = storageLevel.toSparkConnectStorageLevel
+      }
+      var catalog = Spark_Connect_Catalog()
+      catalog.cacheTable = cacheTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Returns true if the table is currently cached in-memory.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  ///   If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func isCached(_ tableName: String) async throws -> Bool {
+    let df = getDataFrame({
+      var isCached = Spark_Connect_IsCached()
+      isCached.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.isCached = isCached
+      return catalog
+    })
+    return "true" == (try await df.collect().first!.get(0) as! String)
+  }
+
+  /// Invalidates and refreshes all the cached data and metadata of the given table.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  ///   If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func refreshTable(_ tableName: String) async throws {
+    let df = getDataFrame({
+      var refreshTable = Spark_Connect_RefreshTable()
+      refreshTable.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.refreshTable = refreshTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Invalidates and refreshes all the cached data (and the associated metadata) for any ``DataFrame``
+  /// that contains the given data source path. Path matching is by checking for sub-directories,
+  /// i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate
+  /// everything that is a subdirectory of "/test/parent".
+  public func refreshByPath(_ path: String) async throws {
+    let df = getDataFrame({
+      var refreshByPath = Spark_Connect_RefreshByPath()
+      refreshByPath.path = path
+      var catalog = Spark_Connect_Catalog()
+      catalog.refreshByPath = refreshByPath
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Removes the specified table from the in-memory cache.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  ///   If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func uncacheTable(_ tableName: String) async throws {
+    let df = getDataFrame({
+      var uncacheTable = Spark_Connect_UncacheTable()
+      uncacheTable.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.uncacheTable = uncacheTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Removes all cached tables from the in-memory cache.
+  public func clearCache() async throws {
+    let df = getDataFrame({
+      var catalog = Spark_Connect_Catalog()
+      catalog.clearCache_p = Spark_Connect_ClearCache()
+      return catalog
+    })
+    try await df.count()
+  }
 }
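
Taken together, these six methods cover the cache-management surface of the Spark SQL `Catalog`. A minimal usage sketch (the table name `people` is illustrative and not part of this commit; a reachable Spark Connect server is assumed):

```swift
import SparkConnect

// Hypothetical walkthrough of the new cache APIs against a demo table.
let spark = try await SparkSession.builder.getOrCreate()
try await spark.range(10).write.saveAsTable("people")

try await spark.catalog.cacheTable("people")  // default storage level
print(try await spark.catalog.isCached("people"))  // true

try await spark.catalog.cacheTable("people", StorageLevel.MEMORY_ONLY)  // explicit level
try await spark.catalog.refreshTable("people")  // re-reads data and metadata
try await spark.catalog.refreshByPath("/")  // refreshes everything cached under "/"

try await spark.catalog.uncacheTable("people")  // evicts this table only
try await spark.catalog.clearCache()  // evicts every cached table

await spark.stop()
```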

Sources/SparkConnect/SparkFileUtils.swift

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ public enum SparkFileUtils {
   /// Create a directory given the abstract pathname
   /// - Parameter url: An URL location.
   /// - Returns: Return true if the directory is successfully created; otherwise, return false.
+  @discardableResult
   static func createDirectory(at url: URL) -> Bool {
     let fileManager = FileManager.default
     do {
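
The `@discardableResult` attribute lets callers ignore the returned `Bool` without a compiler warning, which suits call sites that create a directory purely for its side effect. A small sketch of the effect (the path is illustrative, and since the method is module-internal this assumes code inside the SparkConnect module):

```swift
import Foundation

// With @discardableResult, dropping the Bool result no longer warns.
SparkFileUtils.createDirectory(at: URL(fileURLWithPath: "/tmp/spark-connect-demo"))

// Callers that care can still bind and check the result explicitly.
let created = SparkFileUtils.createDirectory(at: URL(fileURLWithPath: "/tmp/spark-connect-demo"))
print("created:", created)
```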

Tests/SparkConnectTests/CatalogTests.swift

Lines changed: 104 additions & 0 deletions
@@ -111,4 +111,108 @@ struct CatalogTests {
     await spark.stop()
   }
   #endif
+
+  @Test
+  func cacheTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.cacheTable(tableName, StorageLevel.MEMORY_ONLY)
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.cacheTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func isCached() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.isCached("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func refreshTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.refreshTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.refreshTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.refreshTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func refreshByPath() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.refreshByPath("/")
+      #expect(try await spark.catalog.isCached(tableName) == false)
+
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.refreshByPath("/")
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+    await spark.stop()
+  }
+
+  @Test
+  func uncacheTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.uncacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.uncacheTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func clearCache() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.clearCache()
+      #expect(try await spark.catalog.isCached(tableName) == false)
+    })
+    await spark.stop()
+  }
 }
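
Every test wraps its body in `SQLHelper.withTable`, which drops the table after the block runs so repeated test runs start clean. A hedged sketch of what such a scoped-table helper can look like (`withTestTable` is a hypothetical stand-in, not the project's actual `SQLHelper` implementation):

```swift
import SparkConnect

// Hypothetical scoped-table helper: run `body`, then always drop the
// table, whether the body returned normally or threw an error.
func withTestTable(
  _ spark: SparkSession, _ tableName: String,
  _ body: () async throws -> Void
) async throws {
  var thrown: Error? = nil
  do { try await body() } catch { thrown = error }
  try await spark.sql("DROP TABLE IF EXISTS \(tableName)").count()
  if let thrown { throw thrown }
}
```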
