Skip to content

Commit 8893e88

Browse files
authored
GH-44910: [Swift] Fix IPC stream reader and writer impl (#45029)
### Rationale for this change Fixes IPC incorrect stream format issue. Changes have been tested with: 1. directions from apache/arrow-experiments#41 (comment) 2. generated file using generate.py from https://github.com/apache/arrow-experiments/tree/main/data/rand-many-types (removed currently unsupported Swift types) **This PR includes breaking changes to public APIs.** Writer and reader APIs have changed: Reader: fromStream -> fromFileStream Writer: toStream -> toFileStream * GitHub Issue: #44910 Authored-by: Alva Bandy <[email protected]> Signed-off-by: Sutou Kouhei <[email protected]>
1 parent c47f605 commit 8893e88

File tree

4 files changed

+184
-20
lines changed

4 files changed

+184
-20
lines changed

swift/Arrow/Sources/Arrow/ArrowReader.swift

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import FlatBuffers
1919
import Foundation
2020

2121
let FILEMARKER = "ARROW1"
22-
let CONTINUATIONMARKER = -1
22+
let CONTINUATIONMARKER = UInt32(0xFFFFFFFF)
2323

2424
public class ArrowReader { // swiftlint:disable:this type_body_length
2525
private class RecordBatchData {
@@ -216,7 +216,77 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
216216
return .success(RecordBatch(arrowSchema, columns: columns))
217217
}
218218

219-
public func fromStream( // swiftlint:disable:this function_body_length
219+
/*
220+
This is for reading the Arrow streaming format. The Arrow streaming format
221+
is slightly different from the Arrow File format as it doesn't contain a header
222+
and footer.
223+
*/
224+
public func readStreaming( // swiftlint:disable:this function_body_length
225+
_ input: Data,
226+
useUnalignedBuffers: Bool = false
227+
) -> Result<ArrowReaderResult, ArrowError> {
228+
let result = ArrowReaderResult()
229+
var offset: Int = 0
230+
var length = getUInt32(input, offset: offset)
231+
var streamData = input
232+
var schemaMessage: org_apache_arrow_flatbuf_Schema?
233+
while length != 0 {
234+
if length == CONTINUATIONMARKER {
235+
offset += Int(MemoryLayout<UInt32>.size)
236+
length = getUInt32(input, offset: offset)
237+
if length == 0 {
238+
return .success(result)
239+
}
240+
}
241+
242+
offset += Int(MemoryLayout<UInt32>.size)
243+
streamData = input[offset...]
244+
let dataBuffer = ByteBuffer(
245+
data: streamData,
246+
allowReadingUnalignedBuffers: true)
247+
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
248+
switch message.headerType {
249+
case .recordbatch:
250+
do {
251+
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
252+
let recordBatch = try loadRecordBatch(
253+
rbMessage,
254+
schema: schemaMessage!,
255+
arrowSchema: result.schema!,
256+
data: input,
257+
messageEndOffset: (Int64(offset) + Int64(length))).get()
258+
result.batches.append(recordBatch)
259+
offset += Int(message.bodyLength + Int64(length))
260+
length = getUInt32(input, offset: offset)
261+
} catch let error as ArrowError {
262+
return .failure(error)
263+
} catch {
264+
return .failure(.unknownError("Unexpected error: \(error)"))
265+
}
266+
case .schema:
267+
schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
268+
let schemaResult = loadSchema(schemaMessage!)
269+
switch schemaResult {
270+
case .success(let schema):
271+
result.schema = schema
272+
case .failure(let error):
273+
return .failure(error)
274+
}
275+
offset += Int(message.bodyLength + Int64(length))
276+
length = getUInt32(input, offset: offset)
277+
default:
278+
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
279+
}
280+
}
281+
return .success(result)
282+
}
283+
284+
/*
285+
This is for reading the Arrow file format. The Arrow file format supports
286+
random accessing the data. The Arrow file format contains a header and
287+
footer around the Arrow streaming format.
288+
*/
289+
public func readFile( // swiftlint:disable:this function_body_length
220290
_ fileData: Data,
221291
useUnalignedBuffers: Bool = false
222292
) -> Result<ArrowReaderResult, ArrowError> {
@@ -242,7 +312,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
242312
for index in 0 ..< footer.recordBatchesCount {
243313
let recordBatch = footer.recordBatches(at: index)!
244314
var messageLength = fileData.withUnsafeBytes { rawBuffer in
245-
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
315+
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
246316
}
247317

248318
var messageOffset: Int64 = 1
@@ -251,7 +321,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
251321
messageLength = fileData.withUnsafeBytes { rawBuffer in
252322
rawBuffer.loadUnaligned(
253323
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
254-
as: Int32.self)
324+
as: UInt32.self)
255325
}
256326
}
257327

@@ -296,7 +366,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
296366
let markerLength = FILEMARKER.utf8.count
297367
let footerLengthEnd = Int(fileData.count - markerLength)
298368
let data = fileData[..<(footerLengthEnd)]
299-
return fromStream(data)
369+
return readFile(data)
300370
} catch {
301371
return .failure(.unknownError("Error loading file: \(error)"))
302372
}
@@ -340,10 +410,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
340410
} catch {
341411
return .failure(.unknownError("Unexpected error: \(error)"))
342412
}
343-
344413
default:
345414
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
346415
}
347416
}
348417

349418
}
419+
// swiftlint:disable:this file_length

swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool {
289289
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
290290
return startString == FILEMARKER && endString == FILEMARKER
291291
}
292+
293+
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
294+
let token = data.withUnsafeBytes { rawBuffer in
295+
rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
296+
}
297+
return token
298+
}

swift/Arrow/Sources/Arrow/ArrowWriter.swift

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
123123
let startIndex = writer.count
124124
switch writeRecordBatch(batch: batch) {
125125
case .success(let rbResult):
126+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
126127
withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))}
127128
writer.append(rbResult.0)
128129
switch writeRecordBatchData(&writer, batch: batch) {
@@ -232,7 +233,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
232233
return .success(fbb.data)
233234
}
234235

235-
private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
236+
private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
236237
var fbb: FlatBufferBuilder = FlatBufferBuilder()
237238
switch writeSchema(&fbb, schema: info.schema) {
238239
case .success(let schemaOffset):
@@ -264,9 +265,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
264265
return .success(true)
265266
}
266267

267-
public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
268+
public func writeSteaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
269+
let writer: any DataWriter = InMemDataWriter()
270+
switch toMessage(info.schema) {
271+
case .success(let schemaData):
272+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
273+
withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) {writer.append(Data($0))}
274+
writer.append(schemaData)
275+
case .failure(let error):
276+
return .failure(error)
277+
}
278+
279+
for batch in info.batches {
280+
switch toMessage(batch) {
281+
case .success(let batchData):
282+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
283+
withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) {writer.append(Data($0))}
284+
writer.append(batchData[0])
285+
writer.append(batchData[1])
286+
case .failure(let error):
287+
return .failure(error)
288+
}
289+
}
290+
291+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
292+
withUnsafeBytes(of: UInt32(0).littleEndian) {writer.append(Data($0))}
293+
if let memWriter = writer as? InMemDataWriter {
294+
return .success(memWriter.data)
295+
} else {
296+
return .failure(.invalid("Unable to cast writer"))
297+
}
298+
}
299+
300+
public func writeFile(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
268301
var writer: any DataWriter = InMemDataWriter()
269-
switch writeStream(&writer, info: info) {
302+
switch writeFile(&writer, info: info) {
270303
case .success:
271304
if let memWriter = writer as? InMemDataWriter {
272305
return .success(memWriter.data)
@@ -293,7 +326,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
293326

294327
var writer: any DataWriter = FileDataWriter(fileHandle)
295328
writer.append(FILEMARKER.data(using: .utf8)!)
296-
switch writeStream(&writer, info: info) {
329+
switch writeFile(&writer, info: info) {
297330
case .success:
298331
writer.append(FILEMARKER.data(using: .utf8)!)
299332
case .failure(let error):

swift/Arrow/Tests/ArrowTests/IPCTests.swift

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,60 @@ func makeRecordBatch() throws -> RecordBatch {
118118
}
119119
}
120120

121+
final class IPCStreamReaderTests: XCTestCase {
122+
func testRBInMemoryToFromStream() throws {
123+
let schema = makeSchema()
124+
let recordBatch = try makeRecordBatch()
125+
let arrowWriter = ArrowWriter()
126+
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
127+
switch arrowWriter.writeSteaming(writerInfo) {
128+
case .success(let writeData):
129+
let arrowReader = ArrowReader()
130+
switch arrowReader.readStreaming(writeData) {
131+
case .success(let result):
132+
let recordBatches = result.batches
133+
XCTAssertEqual(recordBatches.count, 1)
134+
for recordBatch in recordBatches {
135+
XCTAssertEqual(recordBatch.length, 4)
136+
XCTAssertEqual(recordBatch.columns.count, 5)
137+
XCTAssertEqual(recordBatch.schema.fields.count, 5)
138+
XCTAssertEqual(recordBatch.schema.fields[0].name, "col1")
139+
XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowUInt8)
140+
XCTAssertEqual(recordBatch.schema.fields[1].name, "col2")
141+
XCTAssertEqual(recordBatch.schema.fields[1].type.info, ArrowType.ArrowString)
142+
XCTAssertEqual(recordBatch.schema.fields[2].name, "col3")
143+
XCTAssertEqual(recordBatch.schema.fields[2].type.info, ArrowType.ArrowDate32)
144+
XCTAssertEqual(recordBatch.schema.fields[3].name, "col4")
145+
XCTAssertEqual(recordBatch.schema.fields[3].type.info, ArrowType.ArrowInt32)
146+
XCTAssertEqual(recordBatch.schema.fields[4].name, "col5")
147+
XCTAssertEqual(recordBatch.schema.fields[4].type.info, ArrowType.ArrowFloat)
148+
let columns = recordBatch.columns
149+
XCTAssertEqual(columns[0].nullCount, 2)
150+
let dateVal =
151+
"\((columns[2].array as! AsString).asString(0))" // swiftlint:disable:this force_cast
152+
XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000")
153+
let stringVal =
154+
"\((columns[1].array as! AsString).asString(1))" // swiftlint:disable:this force_cast
155+
XCTAssertEqual(stringVal, "test22")
156+
let uintVal =
157+
"\((columns[0].array as! AsString).asString(0))" // swiftlint:disable:this force_cast
158+
XCTAssertEqual(uintVal, "10")
159+
let stringVal2 =
160+
"\((columns[1].array as! AsString).asString(3))" // swiftlint:disable:this force_cast
161+
XCTAssertEqual(stringVal2, "test44")
162+
let uintVal2 =
163+
"\((columns[0].array as! AsString).asString(3))" // swiftlint:disable:this force_cast
164+
XCTAssertEqual(uintVal2, "44")
165+
}
166+
case.failure(let error):
167+
throw error
168+
}
169+
case .failure(let error):
170+
throw error
171+
}
172+
}
173+
}
174+
121175
final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body_length
122176
func testFileReader_struct() throws {
123177
let fileURL = currentDirectory().appendingPathComponent("../../testdata_struct.arrow")
@@ -204,10 +258,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body
204258
let arrowWriter = ArrowWriter()
205259
// write data from file to a stream
206260
let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs)
207-
switch arrowWriter.toStream(writerInfo) {
261+
switch arrowWriter.writeFile(writerInfo) {
208262
case .success(let writeData):
209263
// read stream back into recordbatches
210-
try checkBoolRecordBatch(arrowReader.fromStream(writeData))
264+
try checkBoolRecordBatch(arrowReader.readFile(writeData))
211265
case .failure(let error):
212266
throw error
213267
}
@@ -227,10 +281,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body
227281
let recordBatch = try makeRecordBatch()
228282
let arrowWriter = ArrowWriter()
229283
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
230-
switch arrowWriter.toStream(writerInfo) {
284+
switch arrowWriter.writeFile(writerInfo) {
231285
case .success(let writeData):
232286
let arrowReader = ArrowReader()
233-
switch arrowReader.fromStream(writeData) {
287+
switch arrowReader.readFile(writeData) {
234288
case .success(let result):
235289
let recordBatches = result.batches
236290
XCTAssertEqual(recordBatches.count, 1)
@@ -279,10 +333,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body
279333
let schema = makeSchema()
280334
let arrowWriter = ArrowWriter()
281335
let writerInfo = ArrowWriter.Info(.schema, schema: schema)
282-
switch arrowWriter.toStream(writerInfo) {
336+
switch arrowWriter.writeFile(writerInfo) {
283337
case .success(let writeData):
284338
let arrowReader = ArrowReader()
285-
switch arrowReader.fromStream(writeData) {
339+
switch arrowReader.readFile(writeData) {
286340
case .success(let result):
287341
XCTAssertNotNil(result.schema)
288342
let schema = result.schema!
@@ -362,10 +416,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body
362416
let dataset = try makeBinaryDataset()
363417
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
364418
let arrowWriter = ArrowWriter()
365-
switch arrowWriter.toStream(writerInfo) {
419+
switch arrowWriter.writeFile(writerInfo) {
366420
case .success(let writeData):
367421
let arrowReader = ArrowReader()
368-
switch arrowReader.fromStream(writeData) {
422+
switch arrowReader.readFile(writeData) {
369423
case .success(let result):
370424
XCTAssertNotNil(result.schema)
371425
let schema = result.schema!
@@ -391,10 +445,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body
391445
let dataset = try makeTimeDataset()
392446
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
393447
let arrowWriter = ArrowWriter()
394-
switch arrowWriter.toStream(writerInfo) {
448+
switch arrowWriter.writeFile(writerInfo) {
395449
case .success(let writeData):
396450
let arrowReader = ArrowReader()
397-
switch arrowReader.fromStream(writeData) {
451+
switch arrowReader.readFile(writeData) {
398452
case .success(let result):
399453
XCTAssertNotNil(result.schema)
400454
let schema = result.schema!

0 commit comments

Comments
 (0)