81 changes: 76 additions & 5 deletions Sources/SparkConnect/ArrowReader.swift
@@ -19,7 +19,7 @@ import FlatBuffers
import Foundation

let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
let CONTINUATIONMARKER = UInt32(0xFFFF_FFFF)

/// @nodoc
public class ArrowReader { // swiftlint:disable:this type_body_length
@@ -240,7 +240,78 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
return .success(RecordBatch(arrowSchema, columns: columns))
}

public func fromStream( // swiftlint:disable:this function_body_length
/*
This is for reading the Arrow streaming format. The streaming format
differs from the Arrow file format in that it carries no header or footer.
*/
public func readStreaming( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
let result = ArrowReaderResult()
var offset: Int = 0
var length = getUInt32(fileData, offset: offset)
var streamData = fileData
var schemaMessage: org_apache_arrow_flatbuf_Schema?
while length != 0 {
if length == CONTINUATIONMARKER {
offset += Int(MemoryLayout<UInt32>.size)
length = getUInt32(fileData, offset: offset)
if length == 0 {
return .success(result)
}
}

offset += Int(MemoryLayout<UInt32>.size)
streamData = fileData[offset...]
let dataBuffer = ByteBuffer(
data: streamData,
allowReadingUnalignedBuffers: true)
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
switch message.headerType {
case .recordbatch:
do {
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
offset += Int(message.bodyLength + Int64(length))
let recordBatch = try loadRecordBatch(
rbMessage,
schema: schemaMessage!,
arrowSchema: result.schema!,
data: fileData,
messageEndOffset: (message.bodyLength + Int64(length))
).get()
result.batches.append(recordBatch)
length = getUInt32(fileData, offset: offset)
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}
case .schema:
schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
let schemaResult = loadSchema(schemaMessage!)
switch schemaResult {
case .success(let schema):
result.schema = schema
case .failure(let error):
return .failure(error)
}
offset += Int(message.bodyLength + Int64(length))
length = getUInt32(fileData, offset: offset)
default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}
return .success(result)
}
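
For context, a minimal usage sketch of the new streaming reader (not part of the diff; `streamData` is a hypothetical `Data` value assumed to hold a complete Arrow IPC stream, such as one produced by `toMemoryStream` below):

// Hypothetical usage of readStreaming; assumes `streamData` contains a schema
// message, zero or more record batches, then the end-of-stream marker.
let reader = ArrowReader()
switch reader.readStreaming(streamData) {
case .success(let result):
    print("read \(result.batches.count) record batch(es)")
case .failure(let error):
    print("read failed: \(error)")
}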

/*
This is for reading the Arrow file format. The file format supports random
access to the data: it wraps the Arrow streaming format with a header and
footer.
*/
public func readFile( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
@@ -266,7 +337,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
for index in 0..<footer.recordBatchesCount {
let recordBatch = footer.recordBatches(at: index)!
var messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
}

var messageOffset: Int64 = 1
@@ -275,7 +346,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
as: Int32.self)
as: UInt32.self)
}
}

@@ -320,7 +391,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
let markerLength = FILEMARKER.utf8.count
let footerLengthEnd = Int(fileData.count - markerLength)
let data = fileData[..<(footerLengthEnd)]
return fromStream(data)
return readFile(data)
} catch {
return .failure(.unknownError("Error loading file: \(error)"))
}
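A corresponding sketch for the file-format entry point (illustrative only; `fileBody` is a hypothetical `Data` value assumed to be the file's bytes with the trailing ARROW1 marker sliced off, as the caller above prepares them):

// Hypothetical usage of readFile on pre-sliced file bytes.
let reader = ArrowReader()
switch reader.readFile(fileBody) {
case .success(let result):
    print("file contained \(result.batches.count) record batch(es)")
case .failure(let error):
    print("read failed: \(error)")
}
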
7 changes: 7 additions & 0 deletions Sources/SparkConnect/ArrowReaderHelper.swift
@@ -312,3 +312,10 @@ func validateFileData(_ data: Data) -> Bool {
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
return startString == FILEMARKER && endString == FILEMARKER
}

func getUInt32(_ data: Data, offset: Int) -> UInt32 {
let token = data.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
}
return token
}
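
The helper uses an unaligned load because the 4-byte length prefixes in the IPC framing can sit at arbitrary offsets. A sketch of how the reader walks one framing unit with it (illustrative, and assumes a little-endian host, since loadUnaligned reads native byte order):

// Illustrative framing walk (hypothetical; mirrors the loop in readStreaming).
var offset = 0
var length = getUInt32(streamData, offset: offset)
if length == CONTINUATIONMARKER {
    // Skip the 0xFFFFFFFF continuation marker and read the metadata length.
    offset += MemoryLayout<UInt32>.size
    length = getUInt32(streamData, offset: offset)
}
// A length of 0 here signals end-of-stream.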
41 changes: 37 additions & 4 deletions Sources/SparkConnect/ArrowWriter.swift
@@ -132,6 +132,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
case .success(let rbResult):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
writer.append(rbResult.0)
switch writeRecordBatchData(&writer, batch: batch) {
@@ -250,7 +251,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
Bool, ArrowError
> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
@@ -284,9 +285,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
let writer: any DataWriter = InMemDataWriter()
switch toMessage(info.schema) {
case .success(let schemaData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) { writer.append(Data($0)) }
writer.append(schemaData)
case .failure(let error):
return .failure(error)
}

for batch in info.batches {
switch toMessage(batch) {
case .success(let batchData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) { writer.append(Data($0)) }
writer.append(batchData[0])
writer.append(batchData[1])
case .failure(let error):
return .failure(error)
}
}

withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(0).littleEndian) { writer.append(Data($0)) }
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
} else {
return .failure(.invalid("Unable to cast writer"))
}
}
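
With the reader changes above, the streaming writer and reader round-trip in memory. A sketch, assuming `info` was built from an existing schema and record batches via ArrowWriter.Info:

// Hypothetical round trip: toMemoryStream frames each message with the
// continuation marker and terminates with a zero length, which is exactly
// what readStreaming expects.
let writer = ArrowWriter()
switch writer.toMemoryStream(info) {
case .success(let data):
    if case .success(let readBack) = ArrowReader().readStreaming(data) {
        print("round-tripped \(readBack.batches.count) batch(es)")
    }
case .failure(let error):
    print("write failed: \(error)")
}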

public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
@@ -313,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
writer.append(FILEMARKER.data(using: .utf8)!)
case .failure(let error):
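For comparison, a sketch of the file-format writer (illustrative; same assumed `info` as above). toFileStream produces the stream body plus footer; only the on-disk path above brackets the result with ARROW1 markers:

// Hypothetical usage of toFileStream.
switch ArrowWriter().toFileStream(info) {
case .success(let fileBody):
    print("file-format body is \(fileBody.count) bytes")
case .failure(let error):
    print("write failed: \(error)")
}
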
2 changes: 0 additions & 2 deletions Tests/SparkConnectTests/CatalogTests.swift
@@ -25,7 +25,6 @@ import Testing
/// A test suite for `Catalog`
@Suite(.serialized)
struct CatalogTests {
#if !os(Linux)
@Test
func currentCatalog() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -300,7 +299,6 @@ struct CatalogTests {
#expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false)
await spark.stop()
}
#endif

@Test
func cacheTable() async throws {
2 changes: 0 additions & 2 deletions Tests/SparkConnectTests/DataFrameInternalTests.swift
@@ -25,7 +25,6 @@ import Testing
@Suite(.serialized)
struct DataFrameInternalTests {

#if !os(Linux)
@Test
func showString() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -83,5 +82,4 @@
""")
await spark.stop()
}
#endif
}
8 changes: 2 additions & 6 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -324,7 +324,6 @@ struct DataFrameTests {
await spark.stop()
}

#if !os(Linux)
@Test
func sort() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -340,7 +339,6 @@
#expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
await spark.stop()
}
#endif

@Test
func table() async throws {
@@ -356,7 +354,6 @@
await spark.stop()
}

#if !os(Linux)
@Test
func collect() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -508,7 +505,7 @@ struct DataFrameTests {
#expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
await spark.stop()
}

#if !os(Linux)  // TODO: Enable this on Linux
@Test
func lateralJoin() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -530,7 +527,7 @@
}
await spark.stop()
}

#endif
@Test
func except() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -759,7 +756,6 @@
])
await spark.stop()
}
#endif

@Test
func storageLevel() async throws {
4 changes: 0 additions & 4 deletions Tests/SparkConnectTests/SparkSessionTests.swift
@@ -86,7 +86,6 @@ struct SparkSessionTests {
await spark.stop()
}

#if !os(Linux)
@Test
func sql() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -97,7 +96,6 @@
}
await spark.stop()
}
#endif

@Test
func table() async throws {
@@ -114,10 +112,8 @@
func time() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.time(spark.range(1000).count) == 1000)
#if !os(Linux)
#expect(try await spark.time(spark.range(1).collect) == [Row(0)])
try await spark.time(spark.range(10).show)
#endif
await spark.stop()
}
