Skip to content

Commit 0e55dd1

Browse files
committed
Test
1 parent e1e8a32 commit 0e55dd1

File tree

7 files changed

+122
-23
lines changed

7 files changed

+122
-23
lines changed

Sources/SparkConnect/ArrowReader.swift

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import FlatBuffers
1919
import Foundation
2020

2121
let FILEMARKER = "ARROW1"
22-
let CONTINUATIONMARKER = -1
22+
let CONTINUATIONMARKER = UInt32(0xFFFF_FFFF)
2323

2424
/// @nodoc
2525
public class ArrowReader { // swiftlint:disable:this type_body_length
@@ -240,7 +240,78 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
240240
return .success(RecordBatch(arrowSchema, columns: columns))
241241
}
242242

243-
public func fromStream( // swiftlint:disable:this function_body_length
243+
/*
244+
This is for reading the Arrow streaming format. The Arrow streaming format
245+
is slightly different from the Arrow file format in that it doesn't contain a header
246+
and footer.
247+
*/
248+
public func readStreaming( // swiftlint:disable:this function_body_length
249+
_ fileData: Data,
250+
useUnalignedBuffers: Bool = false
251+
) -> Result<ArrowReaderResult, ArrowError> {
252+
let result = ArrowReaderResult()
253+
var offset: Int = 0
254+
var length = getUInt32(fileData, offset: offset)
255+
var streamData = fileData
256+
var schemaMessage: org_apache_arrow_flatbuf_Schema?
257+
while length != 0 {
258+
if length == CONTINUATIONMARKER {
259+
offset += Int(MemoryLayout<UInt32>.size)
260+
length = getUInt32(fileData, offset: offset)
261+
if length == 0 {
262+
return .success(result)
263+
}
264+
}
265+
266+
offset += Int(MemoryLayout<UInt32>.size)
267+
streamData = fileData[offset...]
268+
let dataBuffer = ByteBuffer(
269+
data: streamData,
270+
allowReadingUnalignedBuffers: true)
271+
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
272+
switch message.headerType {
273+
case .recordbatch:
274+
do {
275+
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
276+
offset += Int(message.bodyLength + Int64(length))
277+
let recordBatch = try loadRecordBatch(
278+
rbMessage,
279+
schema: schemaMessage!,
280+
arrowSchema: result.schema!,
281+
data: fileData,
282+
messageEndOffset: (message.bodyLength + Int64(length))
283+
).get()
284+
result.batches.append(recordBatch)
285+
length = getUInt32(fileData, offset: offset)
286+
} catch let error as ArrowError {
287+
return .failure(error)
288+
} catch {
289+
return .failure(.unknownError("Unexpected error: \(error)"))
290+
}
291+
case .schema:
292+
schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
293+
let schemaResult = loadSchema(schemaMessage!)
294+
switch schemaResult {
295+
case .success(let schema):
296+
result.schema = schema
297+
case .failure(let error):
298+
return .failure(error)
299+
}
300+
offset += Int(message.bodyLength + Int64(length))
301+
length = getUInt32(fileData, offset: offset)
302+
default:
303+
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
304+
}
305+
}
306+
return .success(result)
307+
}
308+
309+
/*
310+
This is for reading the Arrow file format. The Arrow file format supports
311+
random access to the data. The Arrow file format contains a header and
312+
footer around the Arrow streaming format.
313+
*/
314+
public func readFile( // swiftlint:disable:this function_body_length
244315
_ fileData: Data,
245316
useUnalignedBuffers: Bool = false
246317
) -> Result<ArrowReaderResult, ArrowError> {
@@ -266,7 +337,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
266337
for index in 0..<footer.recordBatchesCount {
267338
let recordBatch = footer.recordBatches(at: index)!
268339
var messageLength = fileData.withUnsafeBytes { rawBuffer in
269-
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
340+
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
270341
}
271342

272343
var messageOffset: Int64 = 1
@@ -275,7 +346,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
275346
messageLength = fileData.withUnsafeBytes { rawBuffer in
276347
rawBuffer.loadUnaligned(
277348
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
278-
as: Int32.self)
349+
as: UInt32.self)
279350
}
280351
}
281352

@@ -320,7 +391,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
320391
let markerLength = FILEMARKER.utf8.count
321392
let footerLengthEnd = Int(fileData.count - markerLength)
322393
let data = fileData[..<(footerLengthEnd)]
323-
return fromStream(data)
394+
return readFile(data)
324395
} catch {
325396
return .failure(.unknownError("Error loading file: \(error)"))
326397
}

Sources/SparkConnect/ArrowReaderHelper.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,10 @@ func validateFileData(_ data: Data) -> Bool {
312312
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
313313
return startString == FILEMARKER && endString == FILEMARKER
314314
}
315+
316+
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
317+
let token = data.withUnsafeBytes { rawBuffer in
318+
rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
319+
}
320+
return token
321+
}

Sources/SparkConnect/ArrowWriter.swift

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
132132
let startIndex = writer.count
133133
switch writeRecordBatch(batch: batch) {
134134
case .success(let rbResult):
135+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
135136
withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
136137
writer.append(rbResult.0)
137138
switch writeRecordBatchData(&writer, batch: batch) {
@@ -250,7 +251,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
250251
return .success(fbb.data)
251252
}
252253

253-
private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
254+
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
254255
Bool, ArrowError
255256
> {
256257
var fbb: FlatBufferBuilder = FlatBufferBuilder()
@@ -284,9 +285,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
284285
return .success(true)
285286
}
286287

287-
public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
288+
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
289+
let writer: any DataWriter = InMemDataWriter()
290+
switch toMessage(info.schema) {
291+
case .success(let schemaData):
292+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
293+
withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) { writer.append(Data($0)) }
294+
writer.append(schemaData)
295+
case .failure(let error):
296+
return .failure(error)
297+
}
298+
299+
for batch in info.batches {
300+
switch toMessage(batch) {
301+
case .success(let batchData):
302+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
303+
withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) { writer.append(Data($0)) }
304+
writer.append(batchData[0])
305+
writer.append(batchData[1])
306+
case .failure(let error):
307+
return .failure(error)
308+
}
309+
}
310+
311+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
312+
withUnsafeBytes(of: UInt32(0).littleEndian) { writer.append(Data($0)) }
313+
if let memWriter = writer as? InMemDataWriter {
314+
return .success(memWriter.data)
315+
} else {
316+
return .failure(.invalid("Unable to cast writer"))
317+
}
318+
}
319+
320+
public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
288321
var writer: any DataWriter = InMemDataWriter()
289-
switch writeStream(&writer, info: info) {
322+
switch writeFileStream(&writer, info: info) {
290323
case .success:
291324
if let memWriter = writer as? InMemDataWriter {
292325
return .success(memWriter.data)
@@ -313,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
313346

314347
var writer: any DataWriter = FileDataWriter(fileHandle)
315348
writer.append(FILEMARKER.data(using: .utf8)!)
316-
switch writeStream(&writer, info: info) {
349+
switch writeFileStream(&writer, info: info) {
317350
case .success:
318351
writer.append(FILEMARKER.data(using: .utf8)!)
319352
case .failure(let error):

Tests/SparkConnectTests/CatalogTests.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import Testing
2424

2525
/// A test suite for `Catalog`
2626
struct CatalogTests {
27-
#if !os(Linux)
2827
@Test
2928
func currentCatalog() async throws {
3029
let spark = try await SparkSession.builder.getOrCreate()
@@ -299,7 +298,6 @@ struct CatalogTests {
299298
#expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false)
300299
await spark.stop()
301300
}
302-
#endif
303301

304302
@Test
305303
func cacheTable() async throws {

Tests/SparkConnectTests/DataFrameInternalTests.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import Testing
2424
/// A test suite for `DataFrame` internal APIs
2525
struct DataFrameInternalTests {
2626

27-
#if !os(Linux)
2827
@Test
2928
func showString() async throws {
3029
let spark = try await SparkSession.builder.getOrCreate()
@@ -82,5 +81,4 @@ struct DataFrameInternalTests {
8281
""")
8382
await spark.stop()
8483
}
85-
#endif
8684
}

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,6 @@ struct DataFrameTests {
323323
await spark.stop()
324324
}
325325

326-
#if !os(Linux)
327326
@Test
328327
func sort() async throws {
329328
let spark = try await SparkSession.builder.getOrCreate()
@@ -339,7 +338,6 @@ struct DataFrameTests {
339338
#expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
340339
await spark.stop()
341340
}
342-
#endif
343341

344342
@Test
345343
func table() async throws {
@@ -355,7 +353,6 @@ struct DataFrameTests {
355353
await spark.stop()
356354
}
357355

358-
#if !os(Linux)
359356
@Test
360357
func collect() async throws {
361358
let spark = try await SparkSession.builder.getOrCreate()
@@ -507,7 +504,7 @@ struct DataFrameTests {
507504
#expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
508505
await spark.stop()
509506
}
510-
507+
#if !os(Linux) // TODO: Enable this on Linux
511508
@Test
512509
func lateralJoin() async throws {
513510
let spark = try await SparkSession.builder.getOrCreate()
@@ -529,7 +526,7 @@ struct DataFrameTests {
529526
}
530527
await spark.stop()
531528
}
532-
529+
#endif
533530
@Test
534531
func except() async throws {
535532
let spark = try await SparkSession.builder.getOrCreate()
@@ -758,7 +755,6 @@ struct DataFrameTests {
758755
])
759756
await spark.stop()
760757
}
761-
#endif
762758

763759
@Test
764760
func storageLevel() async throws {

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ struct SparkSessionTests {
8585
await spark.stop()
8686
}
8787

88-
#if !os(Linux)
8988
@Test
9089
func sql() async throws {
9190
let spark = try await SparkSession.builder.getOrCreate()
@@ -96,7 +95,6 @@ struct SparkSessionTests {
9695
}
9796
await spark.stop()
9897
}
99-
#endif
10098

10199
@Test
102100
func table() async throws {
@@ -113,10 +111,8 @@ struct SparkSessionTests {
113111
func time() async throws {
114112
let spark = try await SparkSession.builder.getOrCreate()
115113
#expect(try await spark.time(spark.range(1000).count) == 1000)
116-
#if !os(Linux)
117114
#expect(try await spark.time(spark.range(1).collect) == [Row(0)])
118115
try await spark.time(spark.range(10).show)
119-
#endif
120116
await spark.stop()
121117
}
122118

0 commit comments

Comments
 (0)