diff --git a/Sources/SparkConnect/ArrowReader.swift b/Sources/SparkConnect/ArrowReader.swift
index de7af4f..6f7b81a 100644
--- a/Sources/SparkConnect/ArrowReader.swift
+++ b/Sources/SparkConnect/ArrowReader.swift
@@ -19,7 +19,7 @@ import FlatBuffers
 import Foundation
 
 let FILEMARKER = "ARROW1"
-let CONTINUATIONMARKER = -1
+let CONTINUATIONMARKER = UInt32(0xFFFF_FFFF)
 
 /// @nodoc
 public class ArrowReader { // swiftlint:disable:this type_body_length
@@ -240,7 +240,78 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
     return .success(RecordBatch(arrowSchema, columns: columns))
   }
 
-  public func fromStream( // swiftlint:disable:this function_body_length
+  /*
+   This is for reading the Arrow streaming format. The Arrow streaming format
+   is slightly different from the Arrow file format as it doesn't contain a
+   header or footer.
+   */
+  public func readStreaming( // swiftlint:disable:this function_body_length
+    _ fileData: Data,
+    useUnalignedBuffers: Bool = false
+  ) -> Result<ArrowReaderResult, ArrowError> {
+    let result = ArrowReaderResult()
+    var offset: Int = 0
+    var length = getUInt32(fileData, offset: offset)
+    var streamData = fileData
+    var schemaMessage: org_apache_arrow_flatbuf_Schema?
+    while length != 0 {
+      if length == CONTINUATIONMARKER {
+        offset += Int(MemoryLayout<UInt32>.size)
+        length = getUInt32(fileData, offset: offset)
+        if length == 0 {
+          return .success(result)
+        }
+      }
+
+      offset += Int(MemoryLayout<UInt32>.size)
+      streamData = fileData[offset...]
+      let dataBuffer = ByteBuffer(
+        data: streamData,
+        allowReadingUnalignedBuffers: true)
+      let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
+      switch message.headerType {
+      case .recordbatch:
+        do {
+          let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
+          offset += Int(message.bodyLength + Int64(length))
+          let recordBatch = try loadRecordBatch(
+            rbMessage,
+            schema: schemaMessage!,
+            arrowSchema: result.schema!,
+            data: fileData,
+            messageEndOffset: (message.bodyLength + Int64(length))
+          ).get()
+          result.batches.append(recordBatch)
+          length = getUInt32(fileData, offset: offset)
+        } catch let error as ArrowError {
+          return .failure(error)
+        } catch {
+          return .failure(.unknownError("Unexpected error: \(error)"))
+        }
+      case .schema:
+        schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
+        let schemaResult = loadSchema(schemaMessage!)
+        switch schemaResult {
+        case .success(let schema):
+          result.schema = schema
+        case .failure(let error):
+          return .failure(error)
+        }
+        offset += Int(message.bodyLength + Int64(length))
+        length = getUInt32(fileData, offset: offset)
+      default:
+        return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
+      }
+    }
+    return .success(result)
+  }
+
+  /*
+   This is for reading the Arrow file format. The Arrow file format supports
+   random access to the data. It wraps a header and footer around the Arrow
+   streaming format.
+   */
+  public func readFile( // swiftlint:disable:this function_body_length
     _ fileData: Data,
     useUnalignedBuffers: Bool = false
   ) -> Result<ArrowReaderResult, ArrowError> {
@@ -266,7 +337,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
     for index in 0..<footer.recordBatchesCount {
            fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<UInt32>.size)),
-           as: Int32.self)
+           as: UInt32.self)
         }
       }
@@ -320,7 +391,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
       let markerLength = FILEMARKER.utf8.count
       let footerLengthEnd = Int(fileData.count - markerLength)
       let data = fileData[..<(footerLengthEnd)]
-      return fromStream(data)
+      return readFile(data)
     } catch {
       return .failure(.unknownError("Error loading file: \(error)"))
diff --git a/Sources/SparkConnect/ArrowReaderHelper.swift b/Sources/SparkConnect/ArrowReaderHelper.swift
index c0bd55b..baa4e93 100644
--- a/Sources/SparkConnect/ArrowReaderHelper.swift
+++ b/Sources/SparkConnect/ArrowReaderHelper.swift
@@ -312,3 +312,10 @@ func validateFileData(_ data: Data) -> Bool {
   let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
   return startString == FILEMARKER && endString == FILEMARKER
 }
+
+func getUInt32(_ data: Data, offset: Int) -> UInt32 {
+  let token = data.withUnsafeBytes { rawBuffer in
+    rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
+  }
+  return token
+}
diff --git a/Sources/SparkConnect/ArrowWriter.swift b/Sources/SparkConnect/ArrowWriter.swift
index 4b644cf..55c9524 100644
--- a/Sources/SparkConnect/ArrowWriter.swift
+++ b/Sources/SparkConnect/ArrowWriter.swift
@@ -132,6 +132,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
       let startIndex = writer.count
       switch writeRecordBatch(batch: batch) {
       case .success(let rbResult):
+        withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
         withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
         writer.append(rbResult.0)
         switch writeRecordBatchData(&writer, batch: batch) {
@@ -250,7 +251,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success(fbb.data)
   }
 
-  private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
+  private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
     Bool, ArrowError
   > {
     var fbb: FlatBufferBuilder = FlatBufferBuilder()
@@ -284,9 +285,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success(true)
   }
 
-  public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
+  public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
+    let writer: any DataWriter = InMemDataWriter()
+    switch toMessage(info.schema) {
+    case .success(let schemaData):
+      withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
+      withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) { writer.append(Data($0)) }
+      writer.append(schemaData)
+    case .failure(let error):
+      return .failure(error)
+    }
+
+    for batch in info.batches {
+      switch toMessage(batch) {
+      case .success(let batchData):
+        withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
+        withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) { writer.append(Data($0)) }
+        writer.append(batchData[0])
+        writer.append(batchData[1])
+      case .failure(let error):
+        return .failure(error)
+      }
+    }
+
+    withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
+    withUnsafeBytes(of: UInt32(0).littleEndian) { writer.append(Data($0)) }
+    if let memWriter = writer as? InMemDataWriter {
+      return .success(memWriter.data)
+    } else {
+      return .failure(.invalid("Unable to cast writer"))
+    }
+  }
+
+  public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
     var writer: any DataWriter = InMemDataWriter()
-    switch writeStream(&writer, info: info) {
+    switch writeFileStream(&writer, info: info) {
     case .success:
       if let memWriter = writer as? InMemDataWriter {
         return .success(memWriter.data)
@@ -313,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     var writer: any DataWriter = FileDataWriter(fileHandle)
     writer.append(FILEMARKER.data(using: .utf8)!)
 
-    switch writeStream(&writer, info: info) {
+    switch writeFileStream(&writer, info: info) {
     case .success:
       writer.append(FILEMARKER.data(using: .utf8)!)
     case .failure(let error):
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift
index 053daf9..d8cca1b 100644
--- a/Tests/SparkConnectTests/CatalogTests.swift
+++ b/Tests/SparkConnectTests/CatalogTests.swift
@@ -25,7 +25,6 @@ import Testing
 /// A test suite for `Catalog`
 @Suite(.serialized)
 struct CatalogTests {
-#if !os(Linux)
   @Test
   func currentCatalog() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
@@ -300,7 +299,6 @@ struct CatalogTests {
     #expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false)
     await spark.stop()
   }
-#endif
 
   @Test
   func cacheTable() async throws {
diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift
index 96e8fc2..1b79419 100644
--- a/Tests/SparkConnectTests/DataFrameInternalTests.swift
+++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift
@@ -25,7 +25,6 @@ import Testing
 
 @Suite(.serialized)
 struct DataFrameInternalTests {
-#if !os(Linux)
   @Test
   func showString() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
@@ -83,5 +82,4 @@ struct DataFrameInternalTests {
       """)
     await spark.stop()
   }
-#endif
 }
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 693f371..6bcef66 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -324,7 +324,6 @@ struct DataFrameTests {
     await spark.stop()
   }
 
-#if !os(Linux)
  @Test
  func sort() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
@@ -340,7 +339,6 @@ struct DataFrameTests {
    #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
    await spark.stop()
  }
-#endif
 
  @Test
  func table() async throws {
@@ -356,7 +354,6 @@ struct DataFrameTests {
    await spark.stop()
  }
 
-#if !os(Linux)
  @Test
  func collect() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
@@ -508,7 +505,7 @@ struct DataFrameTests {
    #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
    await spark.stop()
  }
-
+#if !os(Linux) // TODO: Enable this on Linux
  @Test
  func lateralJoin() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
@@ -530,7 +527,7 @@ struct DataFrameTests {
    }
    await spark.stop()
  }
-
+#endif
  @Test
  func except() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
@@ -759,7 +756,6 @@ struct DataFrameTests {
      ])
    await spark.stop()
  }
-#endif
 
  @Test
  func storageLevel() async throws {
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 9097e99..e65a1cf 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -86,7 +86,6 @@ struct SparkSessionTests {
    await spark.stop()
  }
 
-#if !os(Linux)
  @Test
  func sql() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
@@ -97,7 +96,6 @@ struct SparkSessionTests {
    }
    await spark.stop()
  }
-#endif
 
  @Test
  func table() async throws {
@@ -114,10 +112,8 @@ struct SparkSessionTests {
  func time() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.time(spark.range(1000).count) == 1000)
-#if !os(Linux)
    #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
    try await spark.time(spark.range(10).show)
-#endif
    await spark.stop()
  }
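
Usage sketch (not part of the patch): a minimal round trip through the new streaming APIs. The `ArrowWriter.Info` initializer and the `ArrowSchema`/`RecordBatch` values are assumed from the surrounding codebase and may differ in exact signature.

  // toMemoryStream frames each message as
  //   [0xFFFF_FFFF continuation][UInt32 length][flatbuffer message][body]
  // and terminates the stream with [0xFFFF_FFFF][0x0000_0000];
  // readStreaming walks the same framing back, with no file header or footer.
  func roundTrip(schema: ArrowSchema, batches: [RecordBatch]) throws -> Int {
    let writer = ArrowWriter()
    let info = ArrowWriter.Info(.recordbatch, schema: schema, batches: batches)  // assumed initializer
    let streamData = try writer.toMemoryStream(info).get()

    let reader = ArrowReader()
    let result = try reader.readStreaming(streamData).get()
    return result.batches.count  // expected to equal batches.count
  }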