diff --git a/Sources/SparkConnect/DataFrameReader.swift b/Sources/SparkConnect/DataFrameReader.swift
index 47f022f..cfa41f9 100644
--- a/Sources/SparkConnect/DataFrameReader.swift
+++ b/Sources/SparkConnect/DataFrameReader.swift
@@ -40,6 +40,33 @@ public actor DataFrameReader: Sendable {
     self.sparkSession = sparkSession
   }
 
+  /// Returns the specified table/view as a ``DataFrame``. If it's a table, it must support batch
+  /// reading and the returned ``DataFrame`` is the batch scan query plan of this table. If it's a
+  /// view, the returned ``DataFrame`` is simply the query plan of the view, which can either be a
+  /// batch or streaming query plan.
+  ///
+  /// - Parameter tableName: A qualified or unqualified name that designates a table or view. If a database is
+  ///   specified, it identifies the table/view from the database. Otherwise, it first attempts to
+  ///   find a temporary view with the given name and then match the table/view from the current
+  ///   database. Note that the global temporary view database is also valid here.
+  /// - Returns: A ``DataFrame`` instance.
+  public func table(_ tableName: String) -> DataFrame {
+    var namedTable = NamedTable()
+    namedTable.unparsedIdentifier = tableName
+    namedTable.options = self.extraOptions.toStringDictionary()
+
+    var read = Read()
+    read.namedTable = namedTable
+
+    var relation = Relation()
+    relation.read = read
+
+    var plan = Plan()
+    plan.opType = .root(relation)
+
+    return DataFrame(spark: sparkSession, plan: plan)
+  }
+
   /// Specifies the input data source format.
   /// - Parameter source: A string.
   /// - Returns: A `DataFrameReader`.
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 39b6bbe..2fa583c 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -125,6 +125,20 @@ public actor SparkSession {
     }
   }
 
+  /// Returns the specified table/view as a ``DataFrame``. If it's a table, it must support batch
+  /// reading and the returned ``DataFrame`` is the batch scan query plan of this table. If it's a
+  /// view, the returned ``DataFrame`` is simply the query plan of the view, which can either be a
+  /// batch or streaming query plan.
+  ///
+  /// - Parameter tableName: A qualified or unqualified name that designates a table or view. If a database is
+  ///   specified, it identifies the table/view from the database. Otherwise, it first attempts to
+  ///   find a temporary view with the given name and then match the table/view from the current
+  ///   database. Note that the global temporary view database is also valid here.
+  /// - Returns: A ``DataFrame`` instance.
+  public func table(_ tableName: String) async throws -> DataFrame {
+    return await read.table(tableName)
+  }
+
   /// Executes some code block and prints to stdout the time taken to execute the block.
   /// - Parameter f: A function to execute.
   /// - Returns: The result of the executed code.
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index f82c4f5..df1fe80 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -31,6 +31,7 @@ typealias Filter = Spark_Connect_Filter
 typealias KeyValue = Spark_Connect_KeyValue
 typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
+typealias NamedTable = Spark_Connect_Read.NamedTable
 typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift
index c159b8f..5c0979d 100644
--- a/Tests/SparkConnectTests/DataFrameReaderTests.swift
+++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift
@@ -64,4 +64,14 @@ struct DataFrameReaderTests {
     #expect(try await spark.read.parquet(path, path).count() == 4)
     await spark.stop()
   }
+
+  @Test
+  func table() async throws {
+    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
+    #expect(try await spark.read.table(tableName).count() == 2)
+    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    await spark.stop()
+  }
 }
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index f302349..cba57e4 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -75,6 +75,16 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+  @Test
+  func table() async throws {
+    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
+    #expect(try await spark.table(tableName).count() == 2)
+    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    await spark.stop()
+  }
+
   @Test
   func time() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
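
For reference, a minimal usage sketch of the two entry points this patch adds, DataFrameReader.table and SparkSession.table. It is not part of the patch itself: it assumes a reachable Spark Connect server (resolved however SparkSession.builder.getOrCreate() is configured for this repository) and uses a throwaway table name, people, chosen only for illustration.

import SparkConnect

// Usage sketch, mirroring what the new tests above exercise.
// Assumes a running Spark Connect server; "people" is a hypothetical table name.
func tableRoundTrip() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  // CREATE TABLE ... AS VALUES returns no rows, so count() is 0.
  _ = try await spark.sql("CREATE TABLE people AS VALUES (1), (2)").count()

  // Both new entry points resolve the same table and should see the same rows.
  let viaReader = try await spark.read.table("people").count()  // DataFrameReader.table
  let viaSession = try await spark.table("people").count()      // SparkSession.table
  print(viaReader, viaSession)  // expected: 2 2

  _ = try await spark.sql("DROP TABLE people").count()
  await spark.stop()
}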