Skip to content

Commit b741ae5

Browse files
committed
[SPARK-51993] Support emptyDataFrame and listColumns
### What changes were proposed in this pull request?

This PR aims to support `SparkSession.emptyDataFrame` and `Catalog.listColumns` APIs.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #109 from dongjoon-hyun/SPARK-51993.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 9499253 commit b741ae5

File tree

6 files changed

+77
-0
lines changed

6 files changed

+77
-0
lines changed

Sources/SparkConnect/Catalog.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,22 @@ public actor Catalog: Sendable {
273273
return try await df.collect()[0].getAsBool(0)
274274
}
275275

276+
/// Returns a list of columns for the given table/view or temporary view.
/// - Parameter tableName: a qualified or unqualified name that designates a table/view. It follows the same
///                        resolution rule with SQL: search for temp views first then table/views in the current
///                        database (namespace).
/// - Returns: A ``DataFrame`` of ``Column``.
public func listColumns(_ tableName: String) async throws -> DataFrame {
  // Build a `ListColumns` catalog request and wrap it in a DataFrame.
  return getDataFrame({
    var request = Spark_Connect_ListColumns()
    request.tableName = tableName
    var catalogMessage = Spark_Connect_Catalog()
    catalogMessage.listColumns = request
    return catalogMessage
  })
}
291+
276292
/// Check if the function with the specified name exists. This can either be a temporary function
277293
/// or a function.
278294
/// - Parameter functionName: a qualified or unqualified name that designates a function. It follows the same

Sources/SparkConnect/Documentation.docc/SparkSession.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ let csvDf = spark.read.csv("path/to/file.csv")
3737

3838
### DataFrame Operations
3939

40+
- ``emptyDataFrame``
4041
- ``range(_:_:_:)``
4142
- ``sql(_:)``
4243

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,16 @@ public actor SparkConnectClient {
234234
}
235235
}
236236

237+
/// Create a `Plan` containing an empty `LocalRelation` (no rows, no columns).
/// Used to back `SparkSession.emptyDataFrame`.
func getLocalRelation() -> Plan {
  // An empty schema string yields a relation with zero columns.
  var emptyRelation = Spark_Connect_LocalRelation()
  emptyRelation.schema = ""
  var rel = Relation()
  rel.localRelation = emptyRelation
  var result = Plan()
  result.opType = .root(rel)
  return result
}
246+
237247
/// Create a `Plan` instance for `Range` relation.
238248
/// - Parameters:
239249
/// - start: A start of the range.

Sources/SparkConnect/SparkSession.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ public actor SparkSession {
8383
public func stop() async {
8484
await client.stop()
8585
}
86+
87+
/// A ``DataFrame`` with no rows and no columns, backed by an empty local relation.
public var emptyDataFrame: DataFrame {
  get async {
    // Implicit return of a DataFrame over the client's empty LocalRelation plan.
    await DataFrame(spark: self, plan: client.getLocalRelation())
  }
}
8693

8794
/// Create a ``DataFrame`` with a single ``Int64`` column name `id`, containing elements in a
8895
/// range from 0 to `end` (exclusive) with step value 1.

Tests/SparkConnectTests/CatalogTests.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,40 @@ struct CatalogTests {
143143
await spark.stop()
144144
}
145145

146+
@Test
func listColumns() async throws {
  let session = try await SparkSession.builder.getOrCreate()

  // Columns of a table created from an ORC file.
  let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
  let path = "/tmp/\(tableName)"
  try await SQLHelper.withTable(session, tableName)({
    try await session.range(2).write.orc(path)
    // Spark 4 adds a trailing `isCluster` field to the column row.
    let expectedRows = if await session.version.starts(with: "4.") {
      [Row("id", nil, "bigint", true, false, false, false)]
    } else {
      [Row("id", nil, "bigint", true, false, false)]
    }
    #expect(try await session.catalog.createTable(tableName, path, source: "orc").count() == 2)
    #expect(try await session.catalog.listColumns(tableName).collect() == expectedRows)
    // Qualified name resolves to the same table.
    #expect(try await session.catalog.listColumns("default.\(tableName)").collect() == expectedRows)
  })

  // Columns of a temporary view.
  let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
  try await SQLHelper.withTempView(session, viewName)({
    try await session.range(1).createTempView(viewName)
    let expectedRows = if await session.version.starts(with: "4.") {
      [Row("id", nil, "bigint", false, false, false, false)]
    } else {
      [Row("id", nil, "bigint", false, false, false)]
    }
    #expect(try await session.catalog.listColumns(viewName).collect() == expectedRows)
  })

  await session.stop()
}
179+
146180
@Test
147181
func functionExists() async throws {
148182
let spark = try await SparkSession.builder.getOrCreate()

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ struct SparkSessionTests {
6767
await spark.stop()
6868
}
6969

70+
@Test
func emptyDataFrame() async throws {
  let session = try await SparkSession.builder.getOrCreate()
  // An empty DataFrame has zero rows, zero columns, and is a local relation.
  #expect(try await session.emptyDataFrame.count() == 0)
  #expect(try await session.emptyDataFrame.dtypes.isEmpty)
  #expect(try await session.emptyDataFrame.isLocal())
  await session.stop()
}
78+
7079
@Test
7180
func range() async throws {
7281
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)