diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 8732c9e..237c08c 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -71,11 +71,28 @@ public actor DataFrame: Sendable {
     throw SparkConnectError.UnsupportedOperationException
   }
 
+  /// Return all column names as an array of strings.
+  /// - Returns: A string array of column names.
+  public func columns() async throws -> [String] {
+    var columns: [String] = []
+    try await analyzePlanIfNeeded()
+    for field in self.schema!.struct.fields {
+      columns.append(field.name)
+    }
+    return columns
+  }
+
   /// Return a `JSON` string of data type because we cannot expose the internal type ``DataType``.
   /// - Returns: a `JSON` string.
   public func schema() async throws -> String {
-    var dataType: String? = nil
-
+    try await analyzePlanIfNeeded()
+    return try self.schema!.jsonString()
+  }
+
+  private func analyzePlanIfNeeded() async throws {
+    if self.schema != nil {
+      return
+    }
     try await withGRPCClient(
       transport: .http2NIOPosix(
         target: .dns(host: spark.client.host, port: spark.client.port),
@@ -85,9 +102,8 @@ public actor DataFrame: Sendable {
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(
         spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
-      dataType = try response.schema.schema.jsonString()
+      self.setSchema(response.schema.schema)
     }
-    return dataType!
   }
 
   /// Return the total number of rows.
@@ -266,6 +282,8 @@ public actor DataFrame: Sendable {
     return try await select().limit(1).count() == 0
   }
 
+  /// Persist this `DataFrame` with the default storage level (`MEMORY_AND_DISK`).
+  /// - Returns: A `DataFrame`.
   public func cache() async throws -> DataFrame {
     return try await persist()
   }
@@ -291,6 +309,10 @@ public actor DataFrame: Sendable {
     return self
   }
 
+  /// Mark the `DataFrame` as non-persistent, and remove all blocks for it from memory and disk.
+  /// This will not un-persist any cached data that is built upon this `DataFrame`.
+  /// - Parameter blocking: Whether to block until all blocks are deleted.
+  /// - Returns: A `DataFrame`.
   public func unpersist(blocking: Bool = false) async throws -> DataFrame {
     try await withGRPCClient(
       transport: .http2NIOPosix(
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index c7170d3..c49fa15 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -32,6 +32,16 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func columns() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("SELECT 1 as col1").columns() == ["col1"])
+    #expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns() == ["col1", "col2"])
+    #expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns() == ["col1"])
+    #expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns() == [])
+    await spark.stop()
+  }
+
   @Test
   func schema() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
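
Reviewer note: a minimal usage sketch of the new API, using only calls this diff itself exercises (`SparkSession.builder.getOrCreate()`, `sql`, `columns`, `schema`, `stop`). It assumes a reachable Spark Connect server with the client's default connection settings; the `ColumnsExample` type and the literal SQL are illustrative, not part of this PR.

```swift
import SparkConnect

// Illustrative driver (not part of this PR) exercising the new columns() API.
@main
struct ColumnsExample {
  static func main() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let df = try await spark.sql("SELECT 1 AS id, 'a' AS name")

    // The first schema-dependent call performs one AnalyzePlan RPC
    // and caches the result on the DataFrame actor.
    print(try await df.columns())  // ["id", "name"]

    // Reuses the cached schema; no second RPC is issued.
    print(try await df.schema())   // JSON string of the struct type

    await spark.stop()
  }
}
```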
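On the design: `analyzePlanIfNeeded()` turns `schema` into a lazily populated, actor-isolated cache, so `columns()` and `schema()` share one AnalyzePlan round trip. A minimal sketch of that pattern in isolation (all names here are hypothetical, not this repo's API):

```swift
// Minimal sketch of lazy, actor-isolated memoization, assuming the cached
// value is immutable once set. Names are hypothetical.
actor LazySchemaCache {
  private var cached: String?

  /// Stand-in for the AnalyzePlan RPC round trip.
  private func fetchFromServer() async throws -> String {
    return #"{"type":"struct","fields":[]}"#
  }

  /// Fetches at most once on the happy path; later callers reuse the cache.
  func schema() async throws -> String {
    if let cached { return cached }
    let result = try await fetchFromServer()
    cached = result
    return result
  }
}
```

One caveat worth flagging in review: Swift actors are reentrant across `await`, so two concurrent callers that both observe a `nil` cache can each issue the RPC. Since the analyzed schema is the same either way, the race costs at most a redundant round trip, never an inconsistent cache.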