Skip to content

Commit 86ce0af

Browse files
committed
Test
1 parent c4f254b commit 86ce0af

File tree

7 files changed

+122
-23
lines changed

7 files changed

+122
-23
lines changed

Sources/SparkConnect/ArrowReader.swift

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import FlatBuffers
1919
import Foundation
2020

2121
let FILEMARKER = "ARROW1"
22-
let CONTINUATIONMARKER = -1
22+
let CONTINUATIONMARKER = UInt32(0xFFFF_FFFF)
2323

2424
/// @nodoc
2525
public class ArrowReader { // swiftlint:disable:this type_body_length
@@ -240,7 +240,78 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
240240
return .success(RecordBatch(arrowSchema, columns: columns))
241241
}
242242

243-
public func fromStream( // swiftlint:disable:this function_body_length
243+
/// Reads data in the Arrow IPC *streaming* format. Unlike the Arrow *file*
/// format, the streaming format has no "ARROW1" header/footer: it is a
/// sequence of length-prefixed flatbuffer messages (schema first, then
/// record batches), terminated by a continuation marker followed by a
/// zero length (end-of-stream).
///
/// - Parameters:
///   - fileData: Raw bytes of an Arrow IPC stream.
///   - useUnalignedBuffers: NOTE(review): currently unused — the flatbuffer
///     is always read with `allowReadingUnalignedBuffers: true`; confirm
///     whether this flag should be honored.
/// - Returns: An `ArrowReaderResult` with the decoded schema and record
///   batches, or an `ArrowError` on malformed input.
public func readStreaming( // swiftlint:disable:this function_body_length
    _ fileData: Data,
    useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
    let result = ArrowReaderResult()
    var offset: Int = 0
    var length = getUInt32(fileData, offset: offset)
    var streamData = fileData
    var schemaMessage: org_apache_arrow_flatbuf_Schema?
    while length != 0 {
        if length == CONTINUATIONMARKER {
            // Skip the 0xFFFFFFFF continuation marker and read the actual
            // message length that follows it.
            offset += Int(MemoryLayout<UInt32>.size)
            length = getUInt32(fileData, offset: offset)
            if length == 0 {
                // Continuation marker + zero length == end-of-stream.
                return .success(result)
            }
        }

        offset += Int(MemoryLayout<UInt32>.size)
        streamData = fileData[offset...]
        let dataBuffer = ByteBuffer(
            data: streamData,
            allowReadingUnalignedBuffers: true)
        let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
        switch message.headerType {
        case .recordbatch:
            // A record batch can only be decoded after the schema message;
            // fail cleanly instead of force-unwrapping and crashing on a
            // malformed stream that leads with a record batch.
            guard let schema = schemaMessage, let arrowSchema = result.schema else {
                return .failure(.invalid("Record batch received before schema message"))
            }
            do {
                let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
                offset += Int(message.bodyLength + Int64(length))
                let recordBatch = try loadRecordBatch(
                    rbMessage,
                    schema: schema,
                    arrowSchema: arrowSchema,
                    data: fileData,
                    messageEndOffset: (message.bodyLength + Int64(length))
                ).get()
                result.batches.append(recordBatch)
                length = getUInt32(fileData, offset: offset)
            } catch let error as ArrowError {
                return .failure(error)
            } catch {
                return .failure(.unknownError("Unexpected error: \(error)"))
            }
        case .schema:
            schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
            let schemaResult = loadSchema(schemaMessage!)
            switch schemaResult {
            case .success(let schema):
                result.schema = schema
            case .failure(let error):
                return .failure(error)
            }
            offset += Int(message.bodyLength + Int64(length))
            length = getUInt32(fileData, offset: offset)
        default:
            return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
        }
    }
    // Stream ended with a bare zero length (no continuation marker).
    return .success(result)
}
308+
309+
/*
310+
This is for reading the Arrow file format. The Arrow file format supports
311+
random accessing the data. The Arrow file format contains a header and
312+
footer around the Arrow streaming format.
313+
*/
314+
public func readFile( // swiftlint:disable:this function_body_length
244315
_ fileData: Data,
245316
useUnalignedBuffers: Bool = false
246317
) -> Result<ArrowReaderResult, ArrowError> {
@@ -266,7 +337,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
266337
for index in 0..<footer.recordBatchesCount {
267338
let recordBatch = footer.recordBatches(at: index)!
268339
var messageLength = fileData.withUnsafeBytes { rawBuffer in
269-
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
340+
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
270341
}
271342

272343
var messageOffset: Int64 = 1
@@ -275,7 +346,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
275346
messageLength = fileData.withUnsafeBytes { rawBuffer in
276347
rawBuffer.loadUnaligned(
277348
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
278-
as: Int32.self)
349+
as: UInt32.self)
279350
}
280351
}
281352

@@ -320,7 +391,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
320391
let markerLength = FILEMARKER.utf8.count
321392
let footerLengthEnd = Int(fileData.count - markerLength)
322393
let data = fileData[..<(footerLengthEnd)]
323-
return fromStream(data)
394+
return readFile(data)
324395
} catch {
325396
return .failure(.unknownError("Error loading file: \(error)"))
326397
}

Sources/SparkConnect/ArrowReaderHelper.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,10 @@ func validateFileData(_ data: Data) -> Bool {
312312
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
313313
return startString == FILEMARKER && endString == FILEMARKER
314314
}
315+
316+
/// Loads a `UInt32` (native byte order) from `data` at byte `offset`,
/// tolerating load addresses that are not 4-byte aligned.
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
    data.withUnsafeBytes { $0.loadUnaligned(fromByteOffset: offset, as: UInt32.self) }
}

Sources/SparkConnect/ArrowWriter.swift

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
132132
let startIndex = writer.count
133133
switch writeRecordBatch(batch: batch) {
134134
case .success(let rbResult):
135+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
135136
withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
136137
writer.append(rbResult.0)
137138
switch writeRecordBatchData(&writer, batch: batch) {
@@ -250,7 +251,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
250251
return .success(fbb.data)
251252
}
252253

253-
private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
254+
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
254255
Bool, ArrowError
255256
> {
256257
var fbb: FlatBufferBuilder = FlatBufferBuilder()
@@ -284,9 +285,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
284285
return .success(true)
285286
}
286287

287-
public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
288+
/// Serializes `info` into the Arrow IPC *streaming* format, in memory:
/// the schema message, then each record batch (header + body), each
/// preceded by the continuation marker and a 4-byte little-endian length,
/// and terminated by a continuation marker with zero length (end-of-stream).
///
/// - Parameter info: The schema and record batches to serialize.
/// - Returns: The encoded stream bytes, or an `ArrowError` from `toMessage`.
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
    // Use the concrete writer type directly: the previous
    // `any DataWriter` declaration plus downcast created an unreachable
    // "Unable to cast writer" failure path.
    let writer = InMemDataWriter()

    // Writes the continuation marker followed by the little-endian length
    // prefix that the streaming format requires before every message.
    func appendLengthPrefix(_ length: UInt32) {
        withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
        withUnsafeBytes(of: length.littleEndian) { writer.append(Data($0)) }
    }

    switch toMessage(info.schema) {
    case .success(let schemaData):
        appendLengthPrefix(UInt32(schemaData.count))
        writer.append(schemaData)
    case .failure(let error):
        return .failure(error)
    }

    for batch in info.batches {
        switch toMessage(batch) {
        case .success(let batchData):
            // Index 0 is the flatbuffer message header (its size is the
            // length prefix); index 1 is the record-batch body.
            appendLengthPrefix(UInt32(batchData[0].count))
            writer.append(batchData[0])
            writer.append(batchData[1])
        case .failure(let error):
            return .failure(error)
        }
    }

    // End-of-stream marker: continuation marker + zero length.
    appendLengthPrefix(0)
    return .success(writer.data)
}
319+
320+
public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
288321
var writer: any DataWriter = InMemDataWriter()
289-
switch writeStream(&writer, info: info) {
322+
switch writeFileStream(&writer, info: info) {
290323
case .success:
291324
if let memWriter = writer as? InMemDataWriter {
292325
return .success(memWriter.data)
@@ -313,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
313346

314347
var writer: any DataWriter = FileDataWriter(fileHandle)
315348
writer.append(FILEMARKER.data(using: .utf8)!)
316-
switch writeStream(&writer, info: info) {
349+
switch writeFileStream(&writer, info: info) {
317350
case .success:
318351
writer.append(FILEMARKER.data(using: .utf8)!)
319352
case .failure(let error):

Tests/SparkConnectTests/CatalogTests.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import Testing
2525
/// A test suite for `Catalog`
2626
@Suite(.serialized)
2727
struct CatalogTests {
28-
#if !os(Linux)
2928
@Test
3029
func currentCatalog() async throws {
3130
let spark = try await SparkSession.builder.getOrCreate()
@@ -300,7 +299,6 @@ struct CatalogTests {
300299
#expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false)
301300
await spark.stop()
302301
}
303-
#endif
304302

305303
@Test
306304
func cacheTable() async throws {

Tests/SparkConnectTests/DataFrameInternalTests.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import Testing
2525
@Suite(.serialized)
2626
struct DataFrameInternalTests {
2727

28-
#if !os(Linux)
2928
@Test
3029
func showString() async throws {
3130
let spark = try await SparkSession.builder.getOrCreate()
@@ -83,5 +82,4 @@ struct DataFrameInternalTests {
8382
""")
8483
await spark.stop()
8584
}
86-
#endif
8785
}

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ struct DataFrameTests {
324324
await spark.stop()
325325
}
326326

327-
#if !os(Linux)
328327
@Test
329328
func sort() async throws {
330329
let spark = try await SparkSession.builder.getOrCreate()
@@ -340,7 +339,6 @@ struct DataFrameTests {
340339
#expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
341340
await spark.stop()
342341
}
343-
#endif
344342

345343
@Test
346344
func table() async throws {
@@ -356,7 +354,6 @@ struct DataFrameTests {
356354
await spark.stop()
357355
}
358356

359-
#if !os(Linux)
360357
@Test
361358
func collect() async throws {
362359
let spark = try await SparkSession.builder.getOrCreate()
@@ -508,7 +505,7 @@ struct DataFrameTests {
508505
#expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
509506
await spark.stop()
510507
}
511-
508+
#if !os(Linux) // TODO: Enable this on linux
512509
@Test
513510
func lateralJoin() async throws {
514511
let spark = try await SparkSession.builder.getOrCreate()
@@ -530,7 +527,7 @@ struct DataFrameTests {
530527
}
531528
await spark.stop()
532529
}
533-
530+
#endif
534531
@Test
535532
func except() async throws {
536533
let spark = try await SparkSession.builder.getOrCreate()
@@ -759,7 +756,6 @@ struct DataFrameTests {
759756
])
760757
await spark.stop()
761758
}
762-
#endif
763759

764760
@Test
765761
func storageLevel() async throws {

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ struct SparkSessionTests {
8686
await spark.stop()
8787
}
8888

89-
#if !os(Linux)
9089
@Test
9190
func sql() async throws {
9291
let spark = try await SparkSession.builder.getOrCreate()
@@ -97,7 +96,6 @@ struct SparkSessionTests {
9796
}
9897
await spark.stop()
9998
}
100-
#endif
10199

102100
@Test
103101
func table() async throws {
@@ -114,10 +112,8 @@ struct SparkSessionTests {
114112
func time() async throws {
115113
let spark = try await SparkSession.builder.getOrCreate()
116114
#expect(try await spark.time(spark.range(1000).count) == 1000)
117-
#if !os(Linux)
118115
#expect(try await spark.time(spark.range(1).collect) == [Row(0)])
119116
try await spark.time(spark.range(10).show)
120-
#endif
121117
await spark.stop()
122118
}
123119

0 commit comments

Comments
 (0)