Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions Sources/SparkConnect/Catalog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@ public actor Catalog: Sendable {
return try await df.collect()[0].getAsBool(0)
}

/// Returns a list of columns for the given table/view or temporary view.
/// - Parameter tableName: a qualified or unqualified name that designates a table/view. It follows the same
/// resolution rule with SQL: search for temp views first then table/views in the current
/// database (namespace).
/// - Returns: A ``DataFrame`` of ``Column``.
public func listColumns(_ tableName: String) async throws -> DataFrame {
let df = getDataFrame({
var listColumns = Spark_Connect_ListColumns()
listColumns.tableName = tableName
var catalog = Spark_Connect_Catalog()
catalog.listColumns = listColumns
return catalog
})
return df
}

/// Check if the function with the specified name exists. This can either be a temporary function
/// or a function.
/// - Parameter functionName: a qualified or unqualified name that designates a function. It follows the same
Expand Down
1 change: 1 addition & 0 deletions Sources/SparkConnect/Documentation.docc/SparkSession.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ let csvDf = spark.read.csv("path/to/file.csv")

### DataFrame Operations

- ``emptyDataFrame``
- ``range(_:_:_:)``
- ``sql(_:)``

Expand Down
10 changes: 10 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,16 @@ public actor SparkConnectClient {
}
}

func getLocalRelation() -> Plan {
var localRelation = Spark_Connect_LocalRelation()
localRelation.schema = ""
var relation = Relation()
relation.localRelation = localRelation
var plan = Plan()
plan.opType = .root(relation)
return plan
}

/// Create a `Plan` instance for `Range` relation.
/// - Parameters:
/// - start: A start of the range.
Expand Down
7 changes: 7 additions & 0 deletions Sources/SparkConnect/SparkSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ public actor SparkSession {
public func stop() async {
await client.stop()
}

/// Returns a `DataFrame` with no rows or columns.
public var emptyDataFrame: DataFrame {
get async {
return await DataFrame(spark: self, plan: client.getLocalRelation())
}
}

/// Create a ``DataFrame`` with a single ``Int64`` column name `id`, containing elements in a
/// range from 0 to `end` (exclusive) with step value 1.
Expand Down
34 changes: 34 additions & 0 deletions Tests/SparkConnectTests/CatalogTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,40 @@ struct CatalogTests {
await spark.stop()
}

@Test
func listColumns() async throws {
let spark = try await SparkSession.builder.getOrCreate()

// Table
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
let path = "/tmp/\(tableName)"
try await SQLHelper.withTable(spark, tableName)({
try await spark.range(2).write.orc(path)
let expected = if await spark.version.starts(with: "4.") {
[Row("id", nil, "bigint", true, false, false, false)]
} else {
[Row("id", nil, "bigint", true, false, false)]
}
#expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2)
#expect(try await spark.catalog.listColumns(tableName).collect() == expected)
#expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected)
})

// View
let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
try await SQLHelper.withTempView(spark, viewName)({
try await spark.range(1).createTempView(viewName)
let expected = if await spark.version.starts(with: "4.") {
[Row("id", nil, "bigint", false, false, false, false)]
} else {
[Row("id", nil, "bigint", false, false, false)]
}
#expect(try await spark.catalog.listColumns(viewName).collect() == expected)
})

await spark.stop()
}

@Test
func functionExists() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Expand Down
9 changes: 9 additions & 0 deletions Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ struct SparkSessionTests {
await spark.stop()
}

@Test
func emptyDataFrame() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.emptyDataFrame.count() == 0)
#expect(try await spark.emptyDataFrame.dtypes.isEmpty)
#expect(try await spark.emptyDataFrame.isLocal())
await spark.stop()
}

@Test
func range() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Expand Down
Loading