From d769ec2e7782a97afa3ef65db974d8762bd1599a Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Thu, 3 Apr 2025 11:09:50 +0900
Subject: [PATCH] [SPARK-51702] Revise `sparkSession/read/write/columns/schema/dtypes/storageLevel` API

---
 Sources/SparkConnect/DataFrame.swift         | 48 +++++++++++---------
 Sources/SparkConnect/SparkSession.swift      |  6 +--
 Tests/SparkConnectTests/DataFrameTests.swift | 24 +++++-----
 3 files changed, 41 insertions(+), 37 deletions(-)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 96c36be..b1e831e 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -29,7 +29,7 @@ import Synchronization
 public actor DataFrame: Sendable {
   var spark: SparkSession
   var plan: Plan
-  var schema: DataType? = nil
+  private var _schema: DataType? = nil
   private var batches: [RecordBatch] = [RecordBatch]()
 
   /// Create a new `DataFrame` instance with the given Spark session and plan.
@@ -57,7 +57,7 @@ public actor DataFrame: Sendable {
   /// Set the schema. This is used to store the analyzed schema response from the `Spark Connect` server.
   /// - Parameter schema: The analyzed schema received from the server.
   private func setSchema(_ schema: DataType) {
-    self.schema = schema
+    self._schema = schema
   }
 
   /// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
@@ -67,9 +67,10 @@ public actor DataFrame: Sendable {
   }
 
   /// Return the `SparkSession` of this `DataFrame`.
-  /// - Returns: A `SparkSession`
-  public func sparkSession() -> SparkSession {
-    return self.spark
+  public var sparkSession: SparkSession {
+    get async throws {
+      return self.spark
+    }
   }
 
   /// A method to access the underlying Spark's `RDD`.
@@ -82,32 +83,35 @@ public actor DataFrame: Sendable {
   }
 
   /// Return an array of column name strings.
-  /// - Returns: a string array
-  public func columns() async throws -> [String] {
-    var columns: [String] = []
-    try await analyzePlanIfNeeded()
-    for field in self.schema!.struct.fields {
-      columns.append(field.name)
+  public var columns: [String] {
+    get async throws {
+      var columns: [String] = []
+      try await analyzePlanIfNeeded()
+      for field in self._schema!.struct.fields {
+        columns.append(field.name)
+      }
+      return columns
     }
-    return columns
   }
 
   /// Return the schema as a `JSON` string because we cannot expose the internal type ``DataType``.
-  /// - Returns: a `JSON` string.
-  public func schema() async throws -> String {
-    try await analyzePlanIfNeeded()
-    return try self.schema!.jsonString()
+  public var schema: String {
+    get async throws {
+      try await analyzePlanIfNeeded()
+      return try self._schema!.jsonString()
+    }
   }
 
-  var dtypes: [(String, String)] {
+  /// Returns all column names and their data types as an array.
+  public var dtypes: [(String, String)] {
     get async throws {
       try await analyzePlanIfNeeded()
-      return try self.schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
+      return try self._schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
     }
   }
 
   private func analyzePlanIfNeeded() async throws {
-    if self.schema != nil {
+    if self._schema != nil {
       return
     }
     try await withGRPCClient(
@@ -224,7 +228,7 @@ public actor DataFrame: Sendable {
 
   public func show() async throws {
     try await execute()
-    if let schema = self.schema {
+    if let schema = self._schema {
       var columns: [TextTableColumn] = []
       for f in schema.struct.fields {
         columns.append(TextTableColumn(header: f.name))
@@ -342,7 +346,7 @@ public actor DataFrame: Sendable {
     return self
   }
 
-  var storageLevel: StorageLevel {
+  public var storageLevel: StorageLevel {
     get async throws {
       try await withGRPCClient(
         transport: .http2NIOPosix(
@@ -403,7 +407,7 @@
   }
 
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
-  var write: DataFrameWriter {
+  public var write: DataFrameWriter {
     get {
       return DataFrameWriter(df: self)
     }
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 8a61d86..39b6bbe 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -75,7 +75,7 @@ public actor SparkSession {
   var serverSideSessionID: String = ""
 
   /// A variable for ``SparkContext``. This is designed to throw an exception because `Spark Connect` does not support ``SparkContext``.
-  var sparkContext: SparkContext {
+  public var sparkContext: SparkContext {
     get throws {
       // SQLSTATE: 0A000
       // [UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT]
@@ -119,7 +119,7 @@ public actor SparkSession {
 
   /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
   /// `DataFrame`.
-  var read: DataFrameReader {
+  public var read: DataFrameReader {
     get {
       return DataFrameReader(sparkSession: self)
     }
@@ -140,7 +140,7 @@ public actor SparkSession {
 /// This is defined as the return type of the `SparkSession.sparkContext` method.
 /// This is an empty `Struct` type because the `sparkContext` method is designed to throw
 /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
-struct SparkContext {
+public struct SparkContext: Sendable {
 }
 
 /// A builder to create ``SparkSession``
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 7e903bd..832b031 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -26,7 +26,7 @@ struct DataFrameTests {
   @Test
   func sparkSession() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(1).sparkSession() == spark)
+    #expect(try await spark.range(1).sparkSession == spark)
     await spark.stop()
   }
 
@@ -42,10 +42,10 @@ struct DataFrameTests {
   @Test
   func columns() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.sql("SELECT 1 as col1").columns() == ["col1"])
-    #expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns() == ["col1", "col2"])
-    #expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns() == ["col1"])
-    #expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns() == [])
+    #expect(try await spark.sql("SELECT 1 as col1").columns == ["col1"])
+    #expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns == ["col1", "col2"])
+    #expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns == ["col1"])
+    #expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns == [])
     await spark.stop()
   }
 
@@ -53,19 +53,19 @@ struct DataFrameTests {
   func schema() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
 
-    let schema1 = try await spark.sql("SELECT 'a' as col1").schema()
+    let schema1 = try await spark.sql("SELECT 'a' as col1").schema
     #expect(
       schema1
         == #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
     )
 
-    let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema()
+    let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema
     #expect(
       schema2
         == #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
     )
 
-    let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema()
+    let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema
     #expect(emptySchema == #"{"struct":{}}"#)
     await spark.stop()
   }
@@ -136,7 +136,7 @@ struct DataFrameTests {
   @Test
   func selectNone() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let emptySchema = try await spark.range(1).select().schema()
+    let emptySchema = try await spark.range(1).select().schema
     #expect(emptySchema == #"{"struct":{}}"#)
     await spark.stop()
   }
@@ -144,7 +144,7 @@ struct DataFrameTests {
   @Test
   func select() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let schema = try await spark.range(1).select("id").schema()
+    let schema = try await spark.range(1).select("id").schema
     #expect(
       schema
         == #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"#
@@ -155,7 +155,7 @@ struct DataFrameTests {
   @Test
   func selectMultipleColumns() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema()
+    let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema
     #expect(
       schema
         == #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"#
@@ -167,7 +167,7 @@ struct DataFrameTests {
   func selectInvalidColumn() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     try await #require(throws: Error.self) {
-      let _ = try await spark.range(1).select("invalid").schema()
+      let _ = try await spark.range(1).select("invalid").schema
     }
     await spark.stop()
   }
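
Usage sketch (illustrative, not part of the patch): after this change, `sparkSession`, `columns`, `schema`, `dtypes`, and `storageLevel` are read as `get async throws` properties instead of methods, and `read`/`write`/`storageLevel` become `public`. A minimal sketch, assuming a reachable `Spark Connect` server; the `demo` function name and the inline result comments are hypothetical:

import SparkConnect

func demo() async throws {
  // Session creation, exactly as used in the tests above.
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.sql("SELECT 1 as col1, 2 as col2")

  // `columns()` and `schema()` were methods; they are now awaited throwing
  // properties, like the newly-public `dtypes` and `storageLevel`.
  let names = try await df.columns       // ["col1", "col2"]
  let json = try await df.schema         // schema as a JSON string
  let types = try await df.dtypes        // e.g. [("col1", "int"), ("col2", "int")]
  let level = try await df.storageLevel  // storage level of this DataFrame

  // `df.write` and `spark.read` are plain getters on actors, so only `await` is needed.
  let writer = await df.write            // DataFrameWriter for non-streaming writes
  let reader = await spark.read          // DataFrameReader for non-streaming reads

  _ = (names, json, types, level, writer, reader)
  await spark.stop()
}

Note that `spark.range(1).sparkSession == spark` still requires `try await`, as the test update shows, because the property getter is declared `get async throws` even though it cannot currently throw.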