86 changes: 86 additions & 0 deletions Sources/SparkConnect/Catalog.swift
@@ -199,4 +199,90 @@ public actor Catalog: Sendable {
public func databaseExists(_ dbName: String) async throws -> Bool {
return try await self.listDatabases(pattern: dbName).count > 0
}

/// Caches the specified table in-memory.
/// - Parameters:
/// - tableName: A qualified or unqualified name that designates a table/view.
/// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
///   - storageLevel: The storage level at which to cache the table.
public func cacheTable(_ tableName: String, _ storageLevel: StorageLevel? = nil) async throws {
let df = getDataFrame({
var cacheTable = Spark_Connect_CacheTable()
cacheTable.tableName = tableName
if let storageLevel {
cacheTable.storageLevel = storageLevel.toSparkConnectStorageLevel
}
var catalog = Spark_Connect_Catalog()
catalog.cacheTable = cacheTable
return catalog
})
try await df.count()
}

/// Returns true if the table is currently cached in-memory.
/// - Parameter tableName: A qualified or unqualified name that designates a table/view.
/// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
public func isCached(_ tableName: String) async throws -> Bool {
let df = getDataFrame({
var isCached = Spark_Connect_IsCached()
isCached.tableName = tableName
var catalog = Spark_Connect_Catalog()
catalog.isCached = isCached
return catalog
})
return "true" == (try await df.collect().first!.get(0) as! String)
}

/// Invalidates and refreshes all the cached data and metadata of the given table.
/// - Parameter tableName: A qualified or unqualified name that designates a table/view.
/// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
public func refreshTable(_ tableName: String) async throws {
let df = getDataFrame({
var refreshTable = Spark_Connect_RefreshTable()
refreshTable.tableName = tableName
var catalog = Spark_Connect_Catalog()
catalog.refreshTable = refreshTable
return catalog
})
try await df.count()
}

/// Invalidates and refreshes all the cached data (and the associated metadata) for any ``DataFrame``
/// that contains the given data source path. Path matching is by checking for sub-directories,
/// i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate
/// everything that is a subdirectory of "/test/parent".
public func refreshByPath(_ path: String) async throws {
let df = getDataFrame({
var refreshByPath = Spark_Connect_RefreshByPath()
refreshByPath.path = path
var catalog = Spark_Connect_Catalog()
catalog.refreshByPath = refreshByPath
return catalog
})
try await df.count()
}

/// Removes the specified table from the in-memory cache.
/// - Parameter tableName: A qualified or unqualified name that designates a table/view.
/// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
public func uncacheTable(_ tableName: String) async throws {
let df = getDataFrame({
var uncacheTable = Spark_Connect_UncacheTable()
uncacheTable.tableName = tableName
var catalog = Spark_Connect_Catalog()
catalog.uncacheTable = uncacheTable
return catalog
})
try await df.count()
}

/// Removes all cached tables from the in-memory cache.
public func clearCache() async throws {
let df = getDataFrame({
var catalog = Spark_Connect_Catalog()
catalog.clearCache_p = Spark_Connect_ClearCache()
return catalog
})
try await df.count()
}
}
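
Taken together, these additions give the Swift client the same cache-management surface as the other Spark Connect clients. A minimal usage sketch (the table name `people` is illustrative, and a reachable Spark Connect server is assumed):

import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()

// Cache with the default storage level, then again with an explicit level.
try await spark.catalog.cacheTable("people")
try await spark.catalog.cacheTable("people", StorageLevel.MEMORY_ONLY)

// Check whether the table is currently cached.
print(try await spark.catalog.isCached("people"))  // true while cached

// Refresh cached data and metadata, by table name or by data source path.
try await spark.catalog.refreshTable("people")
try await spark.catalog.refreshByPath("/")

// Drop one table from the cache, or clear every cached table.
try await spark.catalog.uncacheTable("people")
try await spark.catalog.clearCache()

await spark.stop()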
1 change: 1 addition & 0 deletions Sources/SparkConnect/SparkFileUtils.swift
@@ -64,6 +64,7 @@ public enum SparkFileUtils {
/// Creates a directory at the given abstract pathname.
/// - Parameter url: A URL location.
/// - Returns: `true` if the directory was successfully created; otherwise, `false`.
@discardableResult
static func createDirectory(at url: URL) -> Bool {
let fileManager = FileManager.default
do {
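
The `@discardableResult` attribute lets in-module callers invoke `createDirectory(at:)` purely for its side effect without an unused-result warning. A small sketch of such a call (the directory path is illustrative):

import Foundation

let dir = FileManager.default.temporaryDirectory.appendingPathComponent("spark-connect-scratch")
SparkFileUtils.createDirectory(at: dir)  // Bool result intentionally ignored; previously this needed `_ =`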
104 changes: 104 additions & 0 deletions Tests/SparkConnectTests/CatalogTests.swift
@@ -111,4 +111,108 @@ struct CatalogTests {
await spark.stop()
}
#endif

@Test
func cacheTable() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
try await spark.catalog.cacheTable(tableName, StorageLevel.MEMORY_ONLY)
})

try await #require(throws: Error.self) {
try await spark.catalog.cacheTable("not_exist_table")
}
await spark.stop()
}

@Test
func isCached() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
#expect(try await spark.catalog.isCached(tableName) == false)
try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
})

try await #require(throws: Error.self) {
try await spark.catalog.isCached("not_exist_table")
}
await spark.stop()
}

@Test
func refreshTable() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
try await spark.catalog.refreshTable(tableName)
#expect(try await spark.catalog.isCached(tableName) == false)

try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
try await spark.catalog.refreshTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
})

try await #require(throws: Error.self) {
try await spark.catalog.refreshTable("not_exist_table")
}
await spark.stop()
}

@Test
func refreshByPath() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
try await spark.catalog.refreshByPath("/")
#expect(try await spark.catalog.isCached(tableName) == false)

try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
try await spark.catalog.refreshByPath("/")
#expect(try await spark.catalog.isCached(tableName))
})
await spark.stop()
}

@Test
func uncacheTable() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
try await spark.catalog.uncacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName) == false)
})

try await #require(throws: Error.self) {
try await spark.catalog.uncacheTable("not_exist_table")
}
await spark.stop()
}

@Test
func clearCache() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(1).write.saveAsTable(tableName)
try await spark.catalog.cacheTable(tableName)
#expect(try await spark.catalog.isCached(tableName))
try await spark.catalog.clearCache()
#expect(try await spark.catalog.isCached(tableName) == false)
})
await spark.stop()
}
}