48 changes: 26 additions & 22 deletions Sources/SparkConnect/DataFrame.swift
@@ -29,7 +29,7 @@ import Synchronization
public actor DataFrame: Sendable {
var spark: SparkSession
var plan: Plan
-  var schema: DataType? = nil
+  private var _schema: DataType? = nil
Review comment (Member Author): This is renamed to avoid a conflict with public var schema.

private var batches: [RecordBatch] = [RecordBatch]()

  /// Create a new `DataFrame` instance with the given Spark session and plan.
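
The rename matters because Swift does not allow a stored property and a computed property to share one name in the same type; moving the cache to `_schema` frees the `schema` name for the new public accessor below. A minimal self-contained sketch of the pattern (the SchemaCache actor and its stand-in body are hypothetical, not this package's API):

actor SchemaCache {
  // Stored cache keeps an underscore-prefixed name...
  private var _schema: String? = nil

  // ...so the computed accessor can own the public `schema` name.
  var schema: String {
    get async throws {
      if _schema == nil { _schema = "{}" }  // stand-in for the real plan analysis
      return _schema!
    }
  }
}
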
@@ -57,7 +57,7 @@ public actor DataFrame: Sendable {
  /// Set the schema. This is used to store the analyzed schema response from the `Spark Connect` server.
/// - Parameter schema: <#schema description#>
private func setSchema(_ schema: DataType) {
-    self.schema = schema
+    self._schema = schema
}

/// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
@@ -67,9 +67,10 @@
}

/// Return the `SparkSession` of this `DataFrame`.
-  /// - Returns: A `SparkSession`
-  public func sparkSession() -> SparkSession {
-    return self.spark
+  public var sparkSession: SparkSession {
+    get async throws {
+      return self.spark
+    }
  }

/// A method to access the underlying Spark's `RDD`.
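
Since the accessor is now an `async throws` computed property instead of a method, call sites drop the parentheses but keep `try await`. A usage sketch in the style of the tests later in this diff:

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1)
// Before: let session = try await df.sparkSession()
let session = try await df.sparkSession  // After: property access, no parentheses
await spark.stop()
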
@@ -82,32 +83,35 @@
}

  /// Return an array of column name strings
-  /// - Returns: a string array
-  public func columns() async throws -> [String] {
-    var columns: [String] = []
-    try await analyzePlanIfNeeded()
-    for field in self.schema!.struct.fields {
-      columns.append(field.name)
-    }
-    return columns
-  }
+  public var columns: [String] {
+    get async throws {
+      var columns: [String] = []
+      try await analyzePlanIfNeeded()
+      for field in self._schema!.struct.fields {
+        columns.append(field.name)
+      }
+      return columns
+    }
+  }

  /// Return a `JSON` string of the data type because we cannot expose the internal type ``DataType``.
/// - Returns: a `JSON` string.
-  public func schema() async throws -> String {
-    try await analyzePlanIfNeeded()
-    return try self.schema!.jsonString()
+  public var schema: String {
+    get async throws {
+      try await analyzePlanIfNeeded()
+      return try self._schema!.jsonString()
+    }
}

-  var dtypes: [(String, String)] {
+  /// Returns all column names and their data types as an array.
+  public var dtypes: [(String, String)] {
    get async throws {
      try await analyzePlanIfNeeded()
-      return try self.schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
+      return try self._schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
}
}

private func analyzePlanIfNeeded() async throws {
-    if self._schema == nil {
+    if self._schema != nil {
return
}
try await withGRPCClient(
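
Together, `columns`, `schema`, and `dtypes` now read as effectful properties whose first access triggers a one-time plan analysis via `analyzePlanIfNeeded()`. A hedged usage sketch (the commented results, including the `dtypes` type strings, are illustrative):

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT 1 as col1, 'a' as col2")
let names = try await df.columns  // ["col1", "col2"]
let json = try await df.schema    // the schema as a JSON string
let types = try await df.dtypes   // e.g. [("col1", "int"), ("col2", "string")]
await spark.stop()
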
@@ -224,7 +228,7 @@ public actor DataFrame: Sendable {
public func show() async throws {
try await execute()

-    if let schema = self.schema {
+    if let schema = self._schema {
var columns: [TextTableColumn] = []
for f in schema.struct.fields {
columns.append(TextTableColumn(header: f.name))
@@ -342,7 +346,7 @@ public actor DataFrame: Sendable {
return self
}

-  var storageLevel: StorageLevel {
+  public var storageLevel: StorageLevel {
get async throws {
try await withGRPCClient(
transport: .http2NIOPosix(
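
With `storageLevel` public, callers can inspect a DataFrame's caching state directly; the getter round-trips to the Spark Connect server. A sketch, assuming the `persist()` method defined elsewhere in this file:

let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(10).persist()
let level = try await df.storageLevel
print(level)
await spark.stop()
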
@@ -403,7 +407,7 @@
}

/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
-  var write: DataFrameWriter {
+  public var write: DataFrameWriter {
get {
return DataFrameWriter(df: self)
}
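
Making `write` public exposes the writing entry point outside the module. A hedged sketch, given a `spark` session as above; the chained `format`/`save` calls follow the usual Spark DataFrameWriter convention and are assumptions here, not confirmed by this diff:

let writer = try await spark.range(5).write  // DataFrameWriter, now public
try await writer.format("orc").save("/tmp/out")  // assumed writer API
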
6 changes: 3 additions & 3 deletions Sources/SparkConnect/SparkSession.swift
@@ -75,7 +75,7 @@ public actor SparkSession {
var serverSideSessionID: String = ""

  /// A variable for ``SparkContext``. This is designed by Apache Spark to throw exceptions.
-  var sparkContext: SparkContext {
+  public var sparkContext: SparkContext {
get throws {
// SQLSTATE: 0A000
// [UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT]
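
Because `sparkContext` exists only to signal that Spark Connect has no session-local SparkContext, callers should treat every access as throwing. A sketch:

let spark = try await SparkSession.builder.getOrCreate()
do {
  _ = try await spark.sparkContext
} catch {
  // Expected: UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT (SQLSTATE: 0A000)
  print("sparkContext is unsupported over Spark Connect: \(error)")
}
await spark.stop()
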
@@ -119,7 +119,7 @@ public actor SparkSession {

/// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
/// `DataFrame`
-  var read: DataFrameReader {
+  public var read: DataFrameReader {
get {
return DataFrameReader(sparkSession: self)
}
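
Likewise `read` is now the public entry point for loading data into a `DataFrame`. A sketch, given a `spark` session as above; the `orc(_:)` reader method is an assumption based on the usual Spark DataFrameReader surface:

let reader = await spark.read              // DataFrameReader, now public
let df = try await reader.orc("/tmp/out")  // assumed reader method
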
@@ -140,7 +140,7 @@
/// This is defined as the return type of the `SparkSession.sparkContext` method.
/// This is an empty `Struct` type because the `sparkContext` method is designed to throw
/// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
-struct SparkContext {
+public struct SparkContext: Sendable {
}

/// A builder to create ``SparkSession``
24 changes: 12 additions & 12 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -26,7 +26,7 @@ struct DataFrameTests {
@Test
func sparkSession() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(1).sparkSession() == spark)
+    #expect(try await spark.range(1).sparkSession == spark)
await spark.stop()
}

@@ -42,30 +42,30 @@
@Test
func columns() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.sql("SELECT 1 as col1").columns() == ["col1"])
#expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns() == ["col1", "col2"])
#expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns() == ["col1"])
#expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns() == [])
#expect(try await spark.sql("SELECT 1 as col1").columns == ["col1"])
#expect(try await spark.sql("SELECT 1 as col1, 2 as col2").columns == ["col1", "col2"])
#expect(try await spark.sql("SELECT CAST(null as STRING) col1").columns == ["col1"])
#expect(try await spark.sql("DROP TABLE IF EXISTS nonexistent").columns == [])
await spark.stop()
}

@Test
func schema() async throws {
let spark = try await SparkSession.builder.getOrCreate()

-    let schema1 = try await spark.sql("SELECT 'a' as col1").schema()
+    let schema1 = try await spark.sql("SELECT 'a' as col1").schema
#expect(
schema1
== #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
)

-    let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema()
+    let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema
#expect(
schema2
== #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
)

-    let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema()
+    let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema
#expect(emptySchema == #"{"struct":{}}"#)
await spark.stop()
}
@@ -136,15 +136,15 @@ struct DataFrameTests {
@Test
func selectNone() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    let emptySchema = try await spark.range(1).select().schema()
+    let emptySchema = try await spark.range(1).select().schema
#expect(emptySchema == #"{"struct":{}}"#)
await spark.stop()
}

@Test
func select() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    let schema = try await spark.range(1).select("id").schema()
+    let schema = try await spark.range(1).select("id").schema
#expect(
schema
== #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"#
@@ -155,7 +155,7 @@
@Test
func selectMultipleColumns() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema()
+    let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema
#expect(
schema
== #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"#
@@ -167,7 +167,7 @@
func selectInvalidColumn() async throws {
let spark = try await SparkSession.builder.getOrCreate()
try await #require(throws: Error.self) {
-      let _ = try await spark.range(1).select("invalid").schema()
+      let _ = try await spark.range(1).select("invalid").schema
}
await spark.stop()
}