
Commit 422dd14

[SPARK-52340] Update ArrowWriter(Helper)? and ProtoUtil with GH-43170
### What changes were proposed in this pull request?

This PR aims to update `ArrowWriter`, `ArrowWriterHelper`, and `ProtoUtil` to apply the upstream Apache Arrow change, GH-43170.

### Why are the changes needed?

We need to keep syncing with upstream so that we can use the Apache Arrow Swift library when it is released. This is part of the preparation to eventually remove these files from this repository.

### Does this PR introduce _any_ user-facing change?

No behavior change.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #182 from dongjoon-hyun/SPARK-52340.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 2d9e112 commit 422dd14
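
For context, GH-43170 adds struct (nested) type support to the writer path, so schemas containing struct columns can be serialized. Below is a minimal, hypothetical sketch of such a schema using the `ArrowNestedType` constructor that appears in the `ProtoUtil.swift` diff further down; the `ArrowField(_:type:isNullable:)` initializer, the use of `ArrowType(_:)` for primitive types, and the field names are assumptions rather than code from this commit.

```swift
// Hypothetical schema sketch: a struct<street: utf8, zip: int32> column that the
// updated writer can now handle. ArrowNestedType(ArrowType.ArrowStruct, fields:) is
// the constructor used in ProtoUtil.swift below; the ArrowField initializer and the
// ArrowType(_:) calls for primitive types are assumptions.
let children = [
  ArrowField("street", type: ArrowType(ArrowType.ArrowString), isNullable: true),
  ArrowField("zip", type: ArrowType(ArrowType.ArrowInt32), isNullable: true),
]
let addressType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
let addressField = ArrowField("address", type: addressType, isNullable: true)
```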

File tree: 3 files changed, +135 -62 lines changed

Sources/SparkConnect/ArrowWriter.swift

Lines changed: 94 additions & 31 deletions
@@ -18,13 +18,11 @@
 import FlatBuffers
 import Foundation

-/// @nodoc
 public protocol DataWriter {
   var count: Int { get }
   func append(_ data: Data)
 }

-/// @nodoc
 public class ArrowWriter { // swiftlint:disable:this type_body_length
   public class InMemDataWriter: DataWriter {
     public private(set) var data: Data
@@ -77,11 +75,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
   private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result<
     Offset, ArrowError
   > {
+    var fieldsOffset: Offset?
+    if let nestedField = field.type as? ArrowNestedType {
+      var offsets = [Offset]()
+      for field in nestedField.fields {
+        switch writeField(&fbb, field: field) {
+        case .success(let offset):
+          offsets.append(offset)
+        case .failure(let error):
+          return .failure(error)
+        }
+      }
+
+      fieldsOffset = fbb.createVector(ofOffsets: offsets)
+    }
+
     let nameOffset = fbb.create(string: field.name)
     let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type)
     let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb)
     org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb)
     org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb)
+    if let childrenOffset = fieldsOffset {
+      org_apache_arrow_flatbuf_Field.addVectorOf(children: childrenOffset, &fbb)
+    }
+
     switch toFBTypeEnum(field.type) {
     case .success(let type):
       org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb)
@@ -109,7 +126,6 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
       case .failure(let error):
         return .failure(error)
       }
-
     }

     let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
@@ -135,7 +151,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
         withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
         withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
         writer.append(rbResult.0)
-        switch writeRecordBatchData(&writer, batch: batch) {
+        switch writeRecordBatchData(&writer, fields: batch.schema.fields, columns: batch.columns) {
         case .success:
           rbBlocks.append(
             org_apache_arrow_flatbuf_Block(
@@ -153,40 +169,69 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success(rbBlocks)
   }

-  private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
-    let schema = batch.schema
-    var fbb = FlatBufferBuilder()
-
-    // write out field nodes
-    var fieldNodeOffsets = [Offset]()
-    fbb.startVector(
-      schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
-    for index in (0..<schema.fields.count).reversed() {
-      let column = batch.column(index)
+  private func writeFieldNodes(
+    _ fields: [ArrowField], columns: [ArrowArrayHolder], offsets: inout [Offset],
+    fbb: inout FlatBufferBuilder
+  ) {
+    for index in (0..<fields.count).reversed() {
+      let column = columns[index]
       let fieldNode =
         org_apache_arrow_flatbuf_FieldNode(
           length: Int64(column.length),
           nullCount: Int64(column.nullCount))
-      fieldNodeOffsets.append(fbb.create(struct: fieldNode))
+      offsets.append(fbb.create(struct: fieldNode))
+      if let nestedType = column.type as? ArrowNestedType {
+        let structArray = column.array as? StructArray
+        writeFieldNodes(
+          nestedType.fields, columns: structArray!.arrowFields!, offsets: &offsets, fbb: &fbb)
+      }
     }
+  }

-    let nodeOffset = fbb.endVector(len: schema.fields.count)
-
-    // write out buffers
-    var buffers = [org_apache_arrow_flatbuf_Buffer]()
-    var bufferOffset = Int(0)
-    for index in 0..<batch.schema.fields.count {
-      let column = batch.column(index)
+  private func writeBufferInfo(
+    _ fields: [ArrowField],
+    columns: [ArrowArrayHolder],
+    bufferOffset: inout Int,
+    buffers: inout [org_apache_arrow_flatbuf_Buffer],
+    fbb: inout FlatBufferBuilder
+  ) {
+    for index in 0..<fields.count {
+      let column = columns[index]
       let colBufferDataSizes = column.getBufferDataSizes()
       for var bufferDataSize in colBufferDataSizes {
         bufferDataSize = getPadForAlignment(bufferDataSize)
         let buffer = org_apache_arrow_flatbuf_Buffer(
           offset: Int64(bufferOffset), length: Int64(bufferDataSize))
         buffers.append(buffer)
         bufferOffset += bufferDataSize
+        if let nestedType = column.type as? ArrowNestedType {
+          let structArray = column.array as? StructArray
+          writeBufferInfo(
+            nestedType.fields, columns: structArray!.arrowFields!,
+            bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb)
+        }
       }
     }
+  }

+  private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
+    let schema = batch.schema
+    var fbb = FlatBufferBuilder()
+
+    // write out field nodes
+    var fieldNodeOffsets = [Offset]()
+    fbb.startVector(
+      schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
+    writeFieldNodes(schema.fields, columns: batch.columns, offsets: &fieldNodeOffsets, fbb: &fbb)
+    let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count)
+
+    // write out buffers
+    var buffers = [org_apache_arrow_flatbuf_Buffer]()
+    var bufferOffset = Int(0)
+    writeBufferInfo(
+      schema.fields, columns: batch.columns,
+      bufferOffset: &bufferOffset, buffers: &buffers,
+      fbb: &fbb)
     org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb)
     for buffer in buffers.reversed() {
       fbb.create(struct: buffer)
@@ -210,15 +255,32 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
   }

-  private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<
-    Bool, ArrowError
-  > {
-    for index in 0..<batch.schema.fields.count {
-      let column = batch.column(index)
+  private func writeRecordBatchData(
+    _ writer: inout DataWriter, fields: [ArrowField],
+    columns: [ArrowArrayHolder]
+  )
+    -> Result<Bool, ArrowError>
+  {
+    for index in 0..<fields.count {
+      let column = columns[index]
       let colBufferData = column.getBufferData()
       for var bufferData in colBufferData {
         addPadForAlignment(&bufferData)
         writer.append(bufferData)
+        if let nestedType = column.type as? ArrowNestedType {
+          guard let structArray = column.array as? StructArray else {
+            return .failure(.invalid("Struct type array expected for nested type"))
+          }
+
+          switch writeRecordBatchData(
+            &writer, fields: nestedType.fields, columns: structArray.arrowFields!)
+          {
+          case .success:
+            continue
+          case .failure(let error):
+            return .failure(error)
+          }
+        }
       }
     }

@@ -244,11 +306,10 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
       org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: rbBlkEnd, &fbb)
       let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, start: footerStartOffset)
       fbb.finish(offset: footerOffset)
+      return .success(fbb.data)
     case .failure(let error):
       return .failure(error)
     }
-
-    return .success(fbb.data)
   }

   private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
@@ -285,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success(true)
   }

-  public func writeSteaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
+  public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
     let writer: any DataWriter = InMemDataWriter()
     switch toMessage(info.schema) {
     case .success(let schemaData):
@@ -363,7 +424,8 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
       writer.append(message.0)
       addPadForAlignment(&writer)
       var dataWriter: any DataWriter = InMemDataWriter()
-      switch writeRecordBatchData(&dataWriter, batch: batch) {
+      switch writeRecordBatchData(&dataWriter, fields: batch.schema.fields, columns: batch.columns)
+      {
       case .success:
         return .success([
           (writer as! InMemDataWriter).data, // swiftlint:disable:this force_cast
@@ -397,3 +459,4 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
     return .success(fbb.data)
   }
 }
+// swiftlint:disable:this file_length
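
The core of the `ArrowWriter.swift` change is that the flat loops over `batch.schema.fields` were split into recursive helpers (`writeFieldNodes`, `writeBufferInfo`, `writeRecordBatchData`) that descend into `ArrowNestedType` columns, so a struct column's child columns are written right after their parent. Below is a self-contained sketch of that depth-first order, using a hypothetical `Column` type instead of the library's `ArrowArrayHolder`.

```swift
// Self-contained sketch (hypothetical Column type) of the depth-first walk the new
// helpers perform: a struct column is visited first, then its child columns, recursively.
struct Column {
  let name: String
  let children: [Column]
}

func visitDepthFirst(_ columns: [Column], into order: inout [String]) {
  for column in columns {
    order.append(column.name)
    visitDepthFirst(column.children, into: &order)
  }
}

let batchColumns = [
  Column(name: "id", children: []),
  Column(
    name: "address",
    children: [
      Column(name: "street", children: []),
      Column(name: "zip", children: []),
    ]),
]
var order = [String]()
visitDepthFirst(batchColumns, into: &order)
print(order)  // ["id", "address", "street", "zip"]
```

`writeFieldNodes` performs the same recursion but iterates the top-level indices in reverse, matching the `startVector`/`endVector` FlatBuffers construction shown in the diff above.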

Sources/SparkConnect/ArrowWriterHelper.swift

Lines changed: 33 additions & 29 deletions
@@ -25,77 +25,78 @@ extension Data {
 }

 func toFBTypeEnum(_ arrowType: ArrowType) -> Result<org_apache_arrow_flatbuf_Type_, ArrowError> {
-  let infoType = arrowType.info
-  if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16
-    || infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8
-    || infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32
-    || infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32
-  {
+  let typeId = arrowType.id
+  switch typeId {
+  case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64:
     return .success(org_apache_arrow_flatbuf_Type_.int)
-  } else if infoType == ArrowType.ArrowFloat || infoType == ArrowType.ArrowDouble {
+  case .float, .double:
     return .success(org_apache_arrow_flatbuf_Type_.floatingpoint)
-  } else if infoType == ArrowType.ArrowString {
+  case .string:
     return .success(org_apache_arrow_flatbuf_Type_.utf8)
-  } else if infoType == ArrowType.ArrowBinary {
+  case .binary:
     return .success(org_apache_arrow_flatbuf_Type_.binary)
-  } else if infoType == ArrowType.ArrowBool {
+  case .boolean:
     return .success(org_apache_arrow_flatbuf_Type_.bool)
-  } else if infoType == ArrowType.ArrowDate32 || infoType == ArrowType.ArrowDate64 {
+  case .date32, .date64:
     return .success(org_apache_arrow_flatbuf_Type_.date)
-  } else if infoType == ArrowType.ArrowTime32 || infoType == ArrowType.ArrowTime64 {
+  case .time32, .time64:
     return .success(org_apache_arrow_flatbuf_Type_.time)
+  case .strct:
+    return .success(org_apache_arrow_flatbuf_Type_.struct_)
+  default:
+    return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(typeId)"))
   }
-  return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(infoType)"))
 }

-func toFBType( // swiftlint:disable:this cyclomatic_complexity
+func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_length
   _ fbb: inout FlatBufferBuilder,
   arrowType: ArrowType
 ) -> Result<Offset, ArrowError> {
   let infoType = arrowType.info
-  if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 {
+  switch arrowType.id {
+  case .int8, .uint8:
     return .success(
       org_apache_arrow_flatbuf_Int.createInt(
         &fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8))
-  } else if infoType == ArrowType.ArrowInt16 || infoType == ArrowType.ArrowUInt16 {
+  case .int16, .uint16:
     return .success(
       org_apache_arrow_flatbuf_Int.createInt(
         &fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16))
-  } else if infoType == ArrowType.ArrowInt32 || infoType == ArrowType.ArrowUInt32 {
+  case .int32, .uint32:
     return .success(
       org_apache_arrow_flatbuf_Int.createInt(
         &fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32))
-  } else if infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt64 {
+  case .int64, .uint64:
     return .success(
       org_apache_arrow_flatbuf_Int.createInt(
         &fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64))
-  } else if infoType == ArrowType.ArrowFloat {
+  case .float:
     return .success(
       org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .single))
-  } else if infoType == ArrowType.ArrowDouble {
+  case .double:
     return .success(
       org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .double))
-  } else if infoType == ArrowType.ArrowString {
+  case .string:
     return .success(
       org_apache_arrow_flatbuf_Utf8.endUtf8(
         &fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb)))
-  } else if infoType == ArrowType.ArrowBinary {
+  case .binary:
     return .success(
       org_apache_arrow_flatbuf_Binary.endBinary(
         &fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb)))
-  } else if infoType == ArrowType.ArrowBool {
+  case .boolean:
     return .success(
       org_apache_arrow_flatbuf_Bool.endBool(
         &fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb)))
-  } else if infoType == ArrowType.ArrowDate32 {
+  case .date32:
     let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
     org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb)
     return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
-  } else if infoType == ArrowType.ArrowDate64 {
+  case .date64:
     let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
     org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb)
     return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
-  } else if infoType == ArrowType.ArrowTime32 {
+  case .time32:
     let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
     if let timeType = arrowType as? ArrowTypeTime32 {
       org_apache_arrow_flatbuf_Time.add(
@@ -104,7 +105,7 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity
     }

     return .failure(.invalid("Unable to case to Time32"))
-  } else if infoType == ArrowType.ArrowTime64 {
+  case .time64:
     let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
     if let timeType = arrowType as? ArrowTypeTime64 {
       org_apache_arrow_flatbuf_Time.add(
@@ -113,9 +114,12 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity
     }

     return .failure(.invalid("Unable to case to Time64"))
+  case .strct:
+    let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb)
+    return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset))
+  default:
+    return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
   }
-
-  return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
 }

 func addPadForAlignment(_ data: inout Data, alignment: Int = 8) {
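
Switching `toFBTypeEnum` and `toFBType` from chained `==` comparisons on `arrowType.info` to a `switch` over `arrowType.id` means the new struct mapping is a single `case .strct` clause. Below is a hedged usage sketch of the mapping; it assumes `ArrowType` can still be constructed from its info constants the way the removed lines did.

```swift
// Hedged sketch: with the switch on arrowType.id, the int32 type id maps to the
// FlatBuffers `int` type and struct maps to `struct_`, per the cases in the diff above.
let int32Type = ArrowType(ArrowType.ArrowInt32)
switch toFBTypeEnum(int32Type) {
case .success(let fbType):
  print("mapped to FlatBuffers type: \(fbType)")  // expected: .int
case .failure(let error):
  print("mapping failed: \(error)")
}
```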

Sources/SparkConnect/ProtoUtil.swift

Lines changed: 8 additions & 2 deletions
@@ -17,7 +17,7 @@

 import Foundation

-func fromProto( // swiftlint:disable:this cyclomatic_complexity
+func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_length
   field: org_apache_arrow_flatbuf_Field
 ) -> ArrowField {
   let type = field.typeType
@@ -74,7 +74,13 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity
       arrowType = ArrowTypeTime64(arrowUnit)
     }
   case .struct_:
-    arrowType = ArrowType(ArrowType.ArrowStruct)
+    var children = [ArrowField]()
+    for index in 0..<field.childrenCount {
+      let childField = field.children(at: index)!
+      children.append(fromProto(field: childField))
+    }
+
+    arrowType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
   default:
     arrowType = ArrowType(ArrowType.ArrowUnknown)
   }
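
With this change, `fromProto(field:)` recursively converts a struct field's children and returns an `ArrowNestedType` carrying them, rather than a bare struct `ArrowType`. Below is a small hypothetical helper (not part of this commit) that walks such a decoded field and prints its nested layout.

```swift
// Hypothetical helper: prints a decoded ArrowField and, if its type is an
// ArrowNestedType (as produced by the updated fromProto for struct fields),
// recurses into the child fields the same way fromProto built them.
func describe(_ field: ArrowField, indent: String = "") {
  print("\(indent)\(field.name) (nullable: \(field.isNullable))")
  if let nested = field.type as? ArrowNestedType {
    for child in nested.fields {
      describe(child, indent: indent + "  ")
    }
  }
}
```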
