diff --git a/Sources/SparkConnect/DataFrameReader.swift b/Sources/SparkConnect/DataFrameReader.swift
index cfa41f9..ca00b07 100644
--- a/Sources/SparkConnect/DataFrameReader.swift
+++ b/Sources/SparkConnect/DataFrameReader.swift
@@ -160,6 +160,22 @@ public actor DataFrameReader: Sendable {
     return load(paths)
   }
 
+  /// Loads an XML file and returns the result as a `DataFrame`.
+  /// - Parameter path: A path string
+  /// - Returns: A `DataFrame`.
+  public func xml(_ path: String) -> DataFrame {
+    self.source = "xml"
+    return load(path)
+  }
+
+  /// Loads XML files and returns the result as a `DataFrame`.
+  /// - Parameter paths: Path strings
+  /// - Returns: A `DataFrame`.
+  public func xml(_ paths: String...) -> DataFrame {
+    self.source = "xml"
+    return load(paths)
+  }
+
   /// Loads an ORC file and returns the result as a `DataFrame`.
   /// - Parameter path: A path string
   /// - Returns: A `DataFrame`.
diff --git a/Sources/SparkConnect/DataFrameWriter.swift b/Sources/SparkConnect/DataFrameWriter.swift
index ffb0183..6846df2 100644
--- a/Sources/SparkConnect/DataFrameWriter.swift
+++ b/Sources/SparkConnect/DataFrameWriter.swift
@@ -171,6 +171,14 @@ public actor DataFrameWriter: Sendable {
     return try await save(path)
   }
 
+  /// Saves the content of the `DataFrame` in XML format at the specified path.
+  /// - Parameter path: A path string
+  /// - Throws: Any error thrown by `save(_:)`.
+  public func xml(_ path: String) async throws {
+    self.source = "xml"
+    return try await save(path)
+  }
+
   /// Saves the content of the `DataFrame` in ORC format at the specified path.
   /// - Parameter path: A path string
   /// - Returns: A `DataFrame`.
diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift
index 101d842..781282e 100644
--- a/Tests/SparkConnectTests/DataFrameReaderTests.swift
+++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift
@@ -45,6 +45,24 @@ struct DataFrameReaderTests {
     await spark.stop()
   }
 
+  @Test
+  func xml() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let path = "../examples/src/main/resources/people.xml"
+
+    // The explicit `format("xml")` + `load` route and the `xml` shorthand are both expected to yield 3 rows.
+    let viaFormat = try await spark.read.option("rowTag", "person").format("xml").load(path).count()
+    #expect(viaFormat == 3)
+    let viaShorthand = try await spark.read.option("rowTag", "person").xml(path).count()
+    #expect(viaShorthand == 3)
+
+    // Passing the same path twice through the variadic overload is expected to double the row count.
+    let viaVariadic = try await spark.read.option("rowTag", "person").xml(path, path).count()
+    #expect(viaVariadic == 6)
+
+    await spark.stop()
+  }
+
   @Test
   func orc() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift
index 9e9841a..d7fde78 100644
--- a/Tests/SparkConnectTests/DataFrameWriterTests.swift
+++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift
@@ -43,6 +43,19 @@ struct DataFrameWriterTests {
     await spark.stop()
   }
 
+  @Test
+  func xml() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let outputDir = "/tmp/" + UUID().uuidString
+
+    // Round trip: write 2025 rows as XML, then read them back with the same `rowTag`.
+    try await spark.range(2025).write.option("rowTag", "person").xml(outputDir)
+    let rowsReadBack = try await spark.read.option("rowTag", "person").xml(outputDir).count()
+    #expect(rowsReadBack == 2025)
+
+    await spark.stop()
+  }
+
   @Test
   func orc() async throws {
     let tmpDir = "/tmp/" + UUID().uuidString