28 changes: 25 additions & 3 deletions Sources/SparkConnect/DataFrame.swift
@@ -71,11 +71,28 @@ public actor DataFrame: Sendable {
throw SparkConnectError.UnsupportedOperationException
}

/// Return all column names as an array of strings.
/// - Returns: An array of column names.
public func columns() async throws -> [String] {
var columns: [String] = []
try await analyzePlanIfNeeded()
for field in self.schema!.struct.fields {
columns.append(field.name)
}
return columns
}
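
A minimal usage sketch for the new `columns()` API, assuming a reachable Spark Connect server and mirroring the session setup in the tests below:

let spark = try await SparkSession.builder.getOrCreate()
let names = try await spark.sql("SELECT 1 AS a, 2 AS b").columns()
print(names)  // ["a", "b"]
await spark.stop()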

/// Return the schema as a `JSON` string, because the internal type ``DataType`` cannot be exposed.
/// - Returns: a `JSON` string.
public func schema() async throws -> String {
try await analyzePlanIfNeeded()
return try self.schema!.jsonString()
}
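
A similar hedged sketch for `schema()`; the exact JSON layout is produced by the Spark Connect `DataType` proto's `jsonString()`, so the shape shown in the comment is indicative only:

let spark = try await SparkSession.builder.getOrCreate()
let json = try await spark.sql("SELECT CAST(1 AS INT) AS id").schema()
print(json)  // e.g. {"struct":{"fields":[{"name":"id","dataType":{"integer":{}}}]}}
await spark.stop()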

private func analyzePlanIfNeeded() async throws {
if self.schema != nil {
return
}
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
@@ -85,9 +102,8 @@ public actor DataFrame: Sendable {
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(
spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
self.setSchema(response.schema.schema)
}
}
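
`analyzePlanIfNeeded()` memoizes the analyzed schema on the actor, so repeated `columns()` or `schema()` calls trigger at most one AnalyzePlan RPC. A gRPC-free sketch of the same check-then-cache pattern (the `Memoized` type and its members are illustrative, not part of the PR):

actor Memoized {
  private var schema: String? = nil

  func schemaIfNeeded() async -> String {
    if let schema { return schema }    // already analyzed: reuse the cached value
    let analyzed = "struct<id:int>"    // stand-in for the AnalyzePlan response
    self.schema = analyzed             // cache for subsequent calls
    return analyzed
  }
}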

/// Return the total number of rows.
@@ -266,6 +282,8 @@ public actor DataFrame: Sendable {
return try await select().limit(1).count() == 0
}

/// Persist this `DataFrame` with the default storage level (`MEMORY_AND_DISK`).
/// - Returns: A `DataFrame`.
public func cache() async throws -> DataFrame {
return try await persist()
}
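
A hedged usage sketch: because `cache()` simply delegates to `persist()` with the default `MEMORY_AND_DISK` level, the second action below should be served from the cache (assuming a live server):

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM range(10)").cache()
_ = try await df.count()  // materializes and caches the data
_ = try await df.count()  // read back from MEMORY_AND_DISK
await spark.stop()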
@@ -291,6 +309,10 @@ public actor DataFrame: Sendable {
return self
}

/// Mark the `DataFrame` as non-persistent, and remove all blocks for it from memory and disk.
/// This will not un-persist any cached data that is built upon this `DataFrame`.
/// - Parameter blocking: Whether to block until all blocks are deleted.
/// - Returns: A `DataFrame`.
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
try await withGRPCClient(
transport: .http2NIOPosix(
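
The remainder of `unpersist`'s body is elided in the diff. A hedged end-to-end sketch of the persistence lifecycle, passing `blocking: true` to wait until the blocks are actually removed:

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM range(100)").cache()
_ = try await df.count()                    // populate the cache
_ = try await df.unpersist(blocking: true)  // synchronously drop all blocks
await spark.stop()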
10 changes: 10 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -32,6 +32,16 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func columns() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.sql("SELECT 1 as col1").columns() == ["col1"])
#expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns() == ["col1", "col2"])
#expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns() == ["col1"])
#expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns() == [])
await spark.stop()
}

@Test
func schema() async throws {
let spark = try await SparkSession.builder.getOrCreate()