Commit 91d6c98

[SPARK-51719] Support table for SparkSession and DataFrameReader
### What changes were proposed in this pull request?

This PR aims to support the `table` API for `SparkSession` and `DataFrameReader`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No, this is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #42 from dongjoon-hyun/SPARK-51719.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
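For orientation, a minimal usage sketch of the new API follows. This is not part of the commit: it assumes a Spark Connect server reachable with the builder's defaults, and the table name `people` is hypothetical.

```swift
import SparkConnect

// A sketch, assuming a running Spark Connect server; `people` is a
// hypothetical, pre-existing table.
let spark = try await SparkSession.builder.getOrCreate()

// Both entry points build the same named-table scan plan.
let viaSession = try await spark.table("people")
let viaReader = await spark.read.table("people")

// Plans are lazy; count() is what actually executes on the server.
print(try await viaSession.count())
print(try await viaReader.count())

await spark.stop()
```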
1 parent 18c3d18 commit 91d6c98

5 files changed: +62 −0 lines changed


Sources/SparkConnect/DataFrameReader.swift

Lines changed: 27 additions & 0 deletions
@@ -40,6 +40,33 @@ public actor DataFrameReader: Sendable {
     self.sparkSession = sparkSession
   }
 
+  /// Returns the specified table/view as a ``DataFrame``. If it's a table, it must support batch
+  /// reading and the returned ``DataFrame`` is the batch scan query plan of this table. If it's a
+  /// view, the returned ``DataFrame`` is simply the query plan of the view, which can be either a
+  /// batch or a streaming query plan.
+  ///
+  /// - Parameter tableName: A qualified or unqualified name that designates a table or view. If a
+  ///   database is specified, it identifies the table/view from the database. Otherwise, it first
+  ///   attempts to find a temporary view with the given name and then matches the table/view from
+  ///   the current database. Note that the global temporary view database is also valid here.
+  /// - Returns: A ``DataFrame`` instance.
+  public func table(_ tableName: String) -> DataFrame {
+    var namedTable = NamedTable()
+    namedTable.unparsedIdentifier = tableName
+    namedTable.options = self.extraOptions.toStringDictionary()
+
+    var read = Read()
+    read.namedTable = namedTable
+
+    var relation = Relation()
+    relation.read = read
+
+    var plan = Plan()
+    plan.opType = .root(relation)
+
+    return DataFrame(spark: sparkSession, plan: plan)
+  }
+
   /// Specifies the input data source format.
   /// - Parameter source: A string.
   /// - Returns: A `DataFrameReader`.
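One implementation detail worth noting above: the reader's accumulated `extraOptions` are copied into `NamedTable.options`, so options set on the reader before calling `table` travel with the scan plan to the server. A hedged sketch of that flow, assuming a reader `option(_:_:)` method (implied by `extraOptions` but outside this diff) and a hypothetical option key and table name:

```swift
// Options set on the reader ride along in NamedTable.options;
// the key and table name here are illustrative only.
let df = await spark.read
  .option("mergeSchema", "true")
  .table("events")
print(try await df.count())
```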

Sources/SparkConnect/SparkSession.swift

Lines changed: 14 additions & 0 deletions
@@ -125,6 +125,20 @@ public actor SparkSession {
     }
   }
 
+  /// Returns the specified table/view as a ``DataFrame``. If it's a table, it must support batch
+  /// reading and the returned ``DataFrame`` is the batch scan query plan of this table. If it's a
+  /// view, the returned ``DataFrame`` is simply the query plan of the view, which can be either a
+  /// batch or a streaming query plan.
+  ///
+  /// - Parameter tableName: A qualified or unqualified name that designates a table or view. If a
+  ///   database is specified, it identifies the table/view from the database. Otherwise, it first
+  ///   attempts to find a temporary view with the given name and then matches the table/view from
+  ///   the current database. Note that the global temporary view database is also valid here.
+  /// - Returns: A ``DataFrame`` instance.
+  public func table(_ tableName: String) async throws -> DataFrame {
+    return await read.table(tableName)
+  }
+
   /// Executes some code block and prints to stdout the time taken to execute the block.
   /// - Parameter f: A function to execute.
   /// - Returns: The result of the executed code.
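Design note: `SparkSession.table` is a thin delegation to the session's default `DataFrameReader`, so it is equivalent to going through `read` explicitly, just with no reader options applied. A small sketch (table name hypothetical):

```swift
// Equivalent plans: the session-level call delegates to its reader.
let a = try await spark.table("people")
let b = await spark.read.table("people")
```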

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ typealias Filter = Spark_Connect_Filter
 typealias KeyValue = Spark_Connect_KeyValue
 typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
+typealias NamedTable = Spark_Connect_Read.NamedTable
 typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project

Tests/SparkConnectTests/DataFrameReaderTests.swift

Lines changed: 10 additions & 0 deletions
@@ -64,4 +64,14 @@ struct DataFrameReaderTests {
     #expect(try await spark.read.parquet(path, path).count() == 4)
     await spark.stop()
   }
+
+  @Test
+  func table() async throws {
+    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
+    #expect(try await spark.read.table(tableName).count() == 2)
+    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    await spark.stop()
+  }
 }
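A note on the test's shape: the table name is a UUID with the dashes stripped, since dashes are not valid in an unquoted SQL identifier, which also keeps concurrent runs against a shared server from colliding; the final `DROP TABLE` leaves the catalog clean for the next run.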

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 10 additions & 0 deletions
@@ -75,6 +75,16 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+  @Test
+  func table() async throws {
+    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
+    #expect(try await spark.table(tableName).count() == 2)
+    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    await spark.stop()
+  }
+
   @Test
   func time() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
