From 005168bc9a1601a5698f92bce02cf81a0d8ddfd1 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Mon, 14 Apr 2025 16:56:11 +0900
Subject: [PATCH] [SPARK-51792] Support `saveAsTable` and `insertInto`

---
 Sources/SparkConnect/DataFrameWriter.swift    | 46 ++++++++++++++++---
 .../DataFrameWriterTests.swift                | 43 +++++++++++++++++
 2 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/Sources/SparkConnect/DataFrameWriter.swift b/Sources/SparkConnect/DataFrameWriter.swift
index 6846df2..9a142a5 100644
--- a/Sources/SparkConnect/DataFrameWriter.swift
+++ b/Sources/SparkConnect/DataFrameWriter.swift
@@ -113,16 +113,48 @@ public actor DataFrameWriter: Sendable {
   }
 
   private func saveInternal(_ path: String?) async throws {
-    var write = WriteOperation()
+    try await executeWriteOperation({
+      var write = WriteOperation()
+      if let path = path {
+        write.path = path
+      }
+      return write
+    })
+  }
+
+  /// Saves the content of the ``DataFrame`` as the specified table.
+  /// - Parameter tableName: A table name.
+  public func saveAsTable(_ tableName: String) async throws {
+    try await executeWriteOperation({
+      var write = WriteOperation()
+      write.table.tableName = tableName
+      write.table.saveMethod = .saveAsTable
+      return write
+    })
+  }
+
+  /// Inserts the content of the ``DataFrame`` into the specified table. It requires that the
+  /// schema of the ``DataFrame`` is the same as the schema of the table. Unlike ``saveAsTable``,
+  /// ``insertInto`` ignores the column names and uses position-based resolution.
+  /// - Parameter tableName: A table name.
+  public func insertInto(_ tableName: String) async throws {
+    try await executeWriteOperation({
+      var write = WriteOperation()
+      write.table.tableName = tableName
+      write.table.saveMethod = .insertInto
+      return write
+    })
+  }
+
+  private func executeWriteOperation(_ f: () -> WriteOperation) async throws {
+    var write = f()
+
+    // Cannot both be set
+    assert(!(!write.path.isEmpty && !write.table.tableName.isEmpty))
+
     let plan = await self.df.getPlan() as! Plan
     write.input = plan.root
     write.mode = self.saveMode.toSaveMode
-    if let path = path {
-      write.path = path
-    }
-
-    // Cannot both be set
-    // require(!(builder.hasPath && builder.hasTable))
 
     if let source = self.source {
       write.source = source
diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift
index d7fde78..da6d190 100644
--- a/Tests/SparkConnectTests/DataFrameWriterTests.swift
+++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift
@@ -101,6 +101,49 @@ struct DataFrameWriterTests {
     await spark.stop()
   }
 
+  @Test
+  func saveAsTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      #expect(try await spark.read.table(tableName).count() == 1)
+
+      try await #require(throws: Error.self) {
+        try await spark.range(1).write.saveAsTable(tableName)
+      }
+
+      try await spark.range(1).write.mode("overwrite").saveAsTable(tableName)
+      #expect(try await spark.read.table(tableName).count() == 1)
+
+      try await spark.range(1).write.mode("append").saveAsTable(tableName)
+      #expect(try await spark.read.table(tableName).count() == 2)
+    })
+    await spark.stop()
+  }
+
+  @Test
+  func insertInto() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      // Table doesn't exist.
+      try await #require(throws: Error.self) {
+        try await spark.range(1).write.insertInto(tableName)
+      }
+
+      try await spark.range(1).write.saveAsTable(tableName)
+      #expect(try await spark.read.table(tableName).count() == 1)
+
+      try await spark.range(1).write.insertInto(tableName)
+      #expect(try await spark.read.table(tableName).count() == 2)
+
+      try await spark.range(1).write.insertInto(tableName)
+      #expect(try await spark.read.table(tableName).count() == 3)
+    })
+    await spark.stop()
+  }
+
   @Test
   func partitionBy() async throws {
     let tmpDir = "/tmp/" + UUID().uuidString
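Below is a minimal usage sketch of the new API, assuming a running Spark Connect
server and the `SparkConnect` module from the sources above; the table name `t1`
and the use of top-level `await` are illustrative, not part of the patch's tests.

    import SparkConnect

    // Connect to (or create) a session, as in the tests above.
    let spark = try await SparkSession.builder.getOrCreate()

    // saveAsTable creates the table; calling it again with the default
    // mode fails unless "overwrite" or "append" is set via mode(_:).
    try await spark.range(3).write.saveAsTable("t1")

    // insertInto appends by position: column names are ignored, so the
    // DataFrame schema must line up with the existing table schema.
    try await spark.range(3).write.insertInto("t1")

    print(try await spark.read.table("t1").count())  // 6

    await spark.stop()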