diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift index 51f056d..6c86f61 100644 --- a/Sources/SparkConnect/Catalog.swift +++ b/Sources/SparkConnect/Catalog.swift @@ -273,6 +273,22 @@ public actor Catalog: Sendable { return try await df.collect()[0].getAsBool(0) } + /// Returns a list of columns for the given table/view or temporary view. + /// - Parameter tableName: a qualified or unqualified name that designates a table/view. It follows the same + /// resolution rule as SQL: search for temp views first, then tables/views in the current + /// database (namespace). + /// - Returns: A ``DataFrame`` of ``Column``; it is returned without being collected here. + public func listColumns(_ tableName: String) async throws -> DataFrame { + let df = getDataFrame({ + var listColumns = Spark_Connect_ListColumns() + listColumns.tableName = tableName + var catalog = Spark_Connect_Catalog() + catalog.listColumns = listColumns + return catalog + }) + return df + } + /// Check if the function with the specified name exists. This can either be a temporary function /// or a function. /// - Parameter functionName: a qualified or unqualified name that designates a function. 
It follows the same diff --git a/Sources/SparkConnect/Documentation.docc/SparkSession.md b/Sources/SparkConnect/Documentation.docc/SparkSession.md index 9bd4f78..2fa4425 100644 --- a/Sources/SparkConnect/Documentation.docc/SparkSession.md +++ b/Sources/SparkConnect/Documentation.docc/SparkSession.md @@ -37,6 +37,7 @@ let csvDf = spark.read.csv("path/to/file.csv") ### DataFrame Operations +- ``emptyDataFrame`` - ``range(_:_:_:)`` - ``sql(_:)`` diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 016f89f..f1e396f 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -234,6 +234,16 @@ public actor SparkConnectClient { } } + func getLocalRelation() -> Plan { + var localRelation = Spark_Connect_LocalRelation() + localRelation.schema = "" + var relation = Relation() + relation.localRelation = localRelation + var plan = Plan() + plan.opType = .root(relation) + return plan + } + /// Create a `Plan` instance for `Range` relation. /// - Parameters: /// - start: A start of the range. diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index bb8b534..a3fcff9 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -83,6 +83,13 @@ public actor SparkSession { public func stop() async { await client.stop() } + + /// Returns a `DataFrame` with no rows or columns (an empty-schema local relation). + public var emptyDataFrame: DataFrame { + get async { + return await DataFrame(spark: self, plan: client.getLocalRelation()) + } + } /// Create a ``DataFrame`` with a single ``Int64`` column name `id`, containing elements in a /// range from 0 to `end` (exclusive) with step value 1. 
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift index 134a3ac..5739399 100644 --- a/Tests/SparkConnectTests/CatalogTests.swift +++ b/Tests/SparkConnectTests/CatalogTests.swift @@ -143,6 +143,40 @@ struct CatalogTests { await spark.stop() } + @Test + func listColumns() async throws { + let spark = try await SparkSession.builder.getOrCreate() + + // Table: ORC-backed; looked up by bare and `default.`-qualified name + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + let path = "/tmp/\(tableName)" + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(2).write.orc(path) + let expected = if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", true, false, false, false)] + } else { + [Row("id", nil, "bigint", true, false, false)] + } + #expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2) + #expect(try await spark.catalog.listColumns(tableName).collect() == expected) + #expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected) + }) + + // View: temp view created over `spark.range(1)` + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + try await spark.range(1).createTempView(viewName) + let expected = if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", false, false, false, false)] + } else { + [Row("id", nil, "bigint", false, false, false)] + } + #expect(try await spark.catalog.listColumns(viewName).collect() == expected) + }) + + await spark.stop() + } + @Test func functionExists() async throws { let spark = try await SparkSession.builder.getOrCreate() diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index 46bdab5..e57e3df 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -67,6 +67,15 @@ struct SparkSessionTests { await spark.stop() } + @Test + 
func emptyDataFrame() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.emptyDataFrame.count() == 0) + #expect(try await spark.emptyDataFrame.dtypes.isEmpty) + #expect(try await spark.emptyDataFrame.isLocal()) + await spark.stop() + } + @Test func range() async throws { let spark = try await SparkSession.builder.getOrCreate()