Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Sources/SparkConnect/Catalog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -393,4 +393,36 @@ public actor Catalog: Sendable {
})
try await df.count()
}

/// Drops the local temporary view with the given view name in the catalog. If the view has been
/// cached before, then it will also be uncached.
/// - Parameter viewName: The name of the temporary view to be dropped.
/// - Returns: true if the view is dropped successfully, false otherwise.
@discardableResult
public func dropTempView(_ viewName: String) async throws -> Bool {
let df = getDataFrame({
var dropTempView = Spark_Connect_DropTempView()
dropTempView.viewName = viewName
var catalog = Spark_Connect_Catalog()
catalog.dropTempView = dropTempView
return catalog
})
return "true" == (try await df.collect().first!.get(0) as! String)
}

/// Drops the global temporary view with the given view name in the catalog. If the view has been
/// cached before, then it will also be uncached.
/// - Parameter viewName: The unqualified name of the temporary view to be dropped.
/// - Returns: true if the view is dropped successfully, false otherwise.
@discardableResult
public func dropGlobalTempView(_ viewName: String) async throws -> Bool {
let df = getDataFrame({
var dropGlobalTempView = Spark_Connect_DropGlobalTempView()
dropGlobalTempView.viewName = viewName
var catalog = Spark_Connect_Catalog()
catalog.dropGlobalTempView = dropGlobalTempView
return catalog
})
return "true" == (try await df.collect().first!.get(0) as! String)
}
}
32 changes: 32 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,38 @@ public actor DataFrame: Sendable {
return GroupedData(self, GroupType.cube, cols)
}

/// Creates a local temporary view using the given name. The lifetime of this temporary view is
/// tied to the `SparkSession` that was used to create this ``DataFrame``.
/// - Parameter viewName: A view name.
public func createTempView(_ viewName: String) async throws {
try await createTempView(viewName, replace: false, global: false)
}

/// Creates a local temporary view using the given name. The lifetime of this temporary view is
/// tied to the `SparkSession` that was used to create this ``DataFrame``.
/// - Parameter viewName: A view name.
public func createOrReplaceTempView(_ viewName: String) async throws {
try await createTempView(viewName, replace: true, global: false)
}

/// Creates a global temporary view using the given name. The lifetime of this temporary view is
/// tied to this Spark application, but is cross-session.
/// - Parameter viewName: A view name.
public func createGlobalTempView(_ viewName: String) async throws {
try await createTempView(viewName, replace: false, global: true)
}

/// Creates a global temporary view using the given name. The lifetime of this temporary view is
/// tied to this Spark application, but is cross-session.
/// - Parameter viewName: A view name.
public func createOrReplaceGlobalTempView(_ viewName: String) async throws {
try await createTempView(viewName, replace: true, global: true)
}

func createTempView(_ viewName: String, replace: Bool, global: Bool) async throws {
try await spark.client.createTempView(self.plan.root, viewName, replace: replace, isGlobal: global)
}

/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
public var write: DataFrameWriter {
get {
Expand Down
14 changes: 14 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,20 @@ public actor SparkConnectClient {
return plan
}

func createTempView(
_ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
) async throws {
var viewCommand = Spark_Connect_CreateDataFrameViewCommand()
viewCommand.input = child
viewCommand.name = viewName
viewCommand.replace = replace
viewCommand.isGlobal = isGlobal

var command = Spark_Connect_Command()
command.createDataframeView = viewCommand
try await execute(self.sessionID!, command)
}

private enum URIParams {
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
static let PARAM_SESSION_ID = "session_id"
Expand Down
111 changes: 111 additions & 0 deletions Tests/SparkConnectTests/CatalogTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,117 @@ struct CatalogTests {
}
await spark.stop()
}

@Test
func createTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTempView(spark, viewName)({
#expect(try await spark.catalog.tableExists(viewName) == false)
try await spark.range(1).createTempView(viewName)
#expect(try await spark.catalog.tableExists(viewName))

try await #require(throws: Error.self) {
try await spark.range(1).createTempView(viewName)
}
})

try await #require(throws: Error.self) {
try await spark.range(1).createTempView("invalid view name")
}

await spark.stop()
}

@Test
func createOrReplaceTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTempView(spark, viewName)({
#expect(try await spark.catalog.tableExists(viewName) == false)
try await spark.range(1).createOrReplaceTempView(viewName)
#expect(try await spark.catalog.tableExists(viewName))
try await spark.range(1).createOrReplaceTempView(viewName)
})

try await #require(throws: Error.self) {
try await spark.range(1).createOrReplaceTempView("invalid view name")
}

await spark.stop()
}

@Test
func createGlobalTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withGlobalTempView(spark, viewName)({
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false)
try await spark.range(1).createGlobalTempView(viewName)
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))

try await #require(throws: Error.self) {
try await spark.range(1).createGlobalTempView(viewName)
}
})
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false)

try await #require(throws: Error.self) {
try await spark.range(1).createGlobalTempView("invalid view name")
}

await spark.stop()
}

@Test
func createOrReplaceGlobalTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withGlobalTempView(spark, viewName)({
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false)
try await spark.range(1).createOrReplaceGlobalTempView(viewName)
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
try await spark.range(1).createOrReplaceGlobalTempView(viewName)
})
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false)

try await #require(throws: Error.self) {
try await spark.range(1).createOrReplaceGlobalTempView("invalid view name")
}

await spark.stop()
}

@Test
func dropTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTempView(spark, viewName)({ #expect(try await spark.catalog.tableExists(viewName) == false)
try await spark.range(1).createTempView(viewName)
try await spark.catalog.dropTempView(viewName)
#expect(try await spark.catalog.tableExists(viewName) == false)
})

#expect(try await spark.catalog.dropTempView("non_exist_view") == false)
#expect(try await spark.catalog.dropTempView("invalid view name") == false)
await spark.stop()
}

@Test
func dropGlobalTempView() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTempView(spark, viewName)({ #expect(try await spark.catalog.tableExists(viewName) == false)
try await spark.range(1).createGlobalTempView(viewName)
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
try await spark.catalog.dropGlobalTempView(viewName)
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false)
})

#expect(try await spark.catalog.dropGlobalTempView("non_exist_view") == false)
#expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false)
await spark.stop()
}
#endif

@Test
Expand Down
30 changes: 30 additions & 0 deletions Tests/SparkConnectTests/SQLHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,34 @@ struct SQLHelper {
}
return body
}

public static func withTempView(_ spark: SparkSession, _ viewNames: String...) -> (
() async throws -> Void
) async throws -> Void {
func body(_ f: () async throws -> Void) async throws {
try await ErrorUtils.tryWithSafeFinally(
f,
{
for name in viewNames {
try await spark.catalog.dropTempView(name)
}
})
}
return body
}

public static func withGlobalTempView(_ spark: SparkSession, _ viewNames: String...) -> (
() async throws -> Void
) async throws -> Void {
func body(_ f: () async throws -> Void) async throws {
try await ErrorUtils.tryWithSafeFinally(
f,
{
for name in viewNames {
try await spark.catalog.dropGlobalTempView(name)
}
})
}
return body
}
}
Loading