Skip to content

Commit a0252f4

Browse files
committed
apacheGH-43170: Add StructArray support to ArrowWriter
1 parent 1943911 commit a0252f4

File tree

5 files changed: 347 additions and 94 deletions.

swift/Arrow/Sources/Arrow/ArrowWriter.swift

Lines changed: 81 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -71,11 +71,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
7171
public init() {}
7272

7373
private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result<Offset, ArrowError> {
74+
var fieldsOffset: Offset?
75+
if let nestedField = field.type as? ArrowNestedType {
76+
var offsets = [Offset]()
77+
for field in nestedField.fields {
78+
switch writeField(&fbb, field: field) {
79+
case .success(let offset):
80+
offsets.append(offset)
81+
case .failure(let error):
82+
return .failure(error)
83+
}
84+
}
85+
86+
fieldsOffset = fbb.createVector(ofOffsets: offsets)
87+
}
88+
7489
let nameOffset = fbb.create(string: field.name)
7590
let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type)
7691
let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb)
7792
org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb)
7893
org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb)
94+
if let childrenOffset = fieldsOffset {
95+
org_apache_arrow_flatbuf_Field.addVectorOf(children: childrenOffset, &fbb)
96+
}
97+
7998
switch toFBTypeEnum(field.type) {
8099
case .success(let type):
81100
org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb)
@@ -101,7 +120,6 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
101120
case .failure(let error):
102121
return .failure(error)
103122
}
104-
105123
}
106124

107125
let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
@@ -126,7 +144,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
126144
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
127145
withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))}
128146
writer.append(rbResult.0)
129-
switch writeRecordBatchData(&writer, batch: batch) {
147+
switch writeRecordBatchData(&writer, fields: batch.schema.fields, columns: batch.columns) {
130148
case .success:
131149
rbBlocks.append(
132150
org_apache_arrow_flatbuf_Block(offset: Int64(startIndex),
@@ -143,37 +161,59 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
143161
return .success(rbBlocks)
144162
}
145163

146-
private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
147-
let schema = batch.schema
148-
var fbb = FlatBufferBuilder()
149-
150-
// write out field nodes
151-
var fieldNodeOffsets = [Offset]()
152-
fbb.startVector(schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
153-
for index in (0 ..< schema.fields.count).reversed() {
154-
let column = batch.column(index)
164+
private func writeFieldNodes(_ fields: [ArrowField], columns: [ArrowArrayHolder], offsets: inout [Offset],
165+
fbb: inout FlatBufferBuilder) {
166+
for index in (0 ..< fields.count).reversed() {
167+
let column = columns[index]
155168
let fieldNode =
156169
org_apache_arrow_flatbuf_FieldNode(length: Int64(column.length),
157170
nullCount: Int64(column.nullCount))
158-
fieldNodeOffsets.append(fbb.create(struct: fieldNode))
171+
offsets.append(fbb.create(struct: fieldNode))
172+
if let nestedType = column.type as? ArrowNestedType {
173+
let structArray = column.array as? StructArray
174+
writeFieldNodes(nestedType.fields, columns: structArray!.arrowFields!, offsets: &offsets, fbb: &fbb)
175+
}
159176
}
177+
}
160178

161-
let nodeOffset = fbb.endVector(len: schema.fields.count)
162-
163-
// write out buffers
164-
var buffers = [org_apache_arrow_flatbuf_Buffer]()
165-
var bufferOffset = Int(0)
166-
for index in 0 ..< batch.schema.fields.count {
167-
let column = batch.column(index)
179+
private func writeBufferInfo(_ fields: [ArrowField],
180+
columns: [ArrowArrayHolder],
181+
bufferOffset: inout Int,
182+
buffers: inout [org_apache_arrow_flatbuf_Buffer],
183+
fbb: inout FlatBufferBuilder) {
184+
for index in 0 ..< fields.count {
185+
let column = columns[index]
168186
let colBufferDataSizes = column.getBufferDataSizes()
169187
for var bufferDataSize in colBufferDataSizes {
170188
bufferDataSize = getPadForAlignment(bufferDataSize)
171189
let buffer = org_apache_arrow_flatbuf_Buffer(offset: Int64(bufferOffset), length: Int64(bufferDataSize))
172190
buffers.append(buffer)
173191
bufferOffset += bufferDataSize
192+
if let nestedType = column.type as? ArrowNestedType {
193+
let structArray = column.array as? StructArray
194+
writeBufferInfo(nestedType.fields, columns: structArray!.arrowFields!,
195+
bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb)
196+
}
174197
}
175198
}
199+
}
176200

201+
private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
202+
let schema = batch.schema
203+
var fbb = FlatBufferBuilder()
204+
205+
// write out field nodes
206+
var fieldNodeOffsets = [Offset]()
207+
fbb.startVector(schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
208+
writeFieldNodes(schema.fields, columns: batch.columns, offsets: &fieldNodeOffsets, fbb: &fbb)
209+
let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count)
210+
211+
// write out buffers
212+
var buffers = [org_apache_arrow_flatbuf_Buffer]()
213+
var bufferOffset = Int(0)
214+
writeBufferInfo(schema.fields, columns: batch.columns,
215+
bufferOffset: &bufferOffset, buffers: &buffers,
216+
fbb: &fbb)
177217
org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb)
178218
for buffer in buffers.reversed() {
179219
fbb.create(struct: buffer)
@@ -196,13 +236,28 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
196236
return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
197237
}
198238

199-
private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<Bool, ArrowError> {
200-
for index in 0 ..< batch.schema.fields.count {
201-
let column = batch.column(index)
239+
private func writeRecordBatchData(
240+
_ writer: inout DataWriter, fields: [ArrowField],
241+
columns: [ArrowArrayHolder])
242+
-> Result<Bool, ArrowError> {
243+
for index in 0 ..< fields.count {
244+
let column = columns[index]
202245
let colBufferData = column.getBufferData()
203246
for var bufferData in colBufferData {
204247
addPadForAlignment(&bufferData)
205248
writer.append(bufferData)
249+
if let nestedType = column.type as? ArrowNestedType {
250+
guard let structArray = column.array as? StructArray else {
251+
return .failure(.invalid("Struct type array expected for nested type"))
252+
}
253+
254+
switch writeRecordBatchData(&writer, fields: nestedType.fields, columns: structArray.arrowFields!) {
255+
case .success:
256+
continue
257+
case .failure(let error):
258+
return .failure(error)
259+
}
260+
}
206261
}
207262
}
208263

@@ -226,11 +281,10 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
226281
org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: rbBlkEnd, &fbb)
227282
let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, start: footerStartOffset)
228283
fbb.finish(offset: footerOffset)
284+
return .success(fbb.data)
229285
case .failure(let error):
230286
return .failure(error)
231287
}
232-
233-
return .success(fbb.data)
234288
}
235289

236290
private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
@@ -265,7 +319,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
265319
return .success(true)
266320
}
267321

268-
public func writeSteaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
322+
public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
269323
let writer: any DataWriter = InMemDataWriter()
270324
switch toMessage(info.schema) {
271325
case .success(let schemaData):
@@ -343,7 +397,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
343397
writer.append(message.0)
344398
addPadForAlignment(&writer)
345399
var dataWriter: any DataWriter = InMemDataWriter()
346-
switch writeRecordBatchData(&dataWriter, batch: batch) {
400+
switch writeRecordBatchData(&dataWriter, fields: batch.schema.fields, columns: batch.columns) {
347401
case .success:
348402
return .success([
349403
(writer as! InMemDataWriter).data, // swiftlint:disable:this force_cast
@@ -377,3 +431,4 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
377431
return .success(fbb.data)
378432
}
379433
}
434+
// swiftlint:disable:this file_length

swift/Arrow/Sources/Arrow/ArrowWriterHelper.swift

Lines changed: 33 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -25,85 +25,90 @@ extension Data {
2525
}
2626

2727
func toFBTypeEnum(_ arrowType: ArrowType) -> Result<org_apache_arrow_flatbuf_Type_, ArrowError> {
28-
let infoType = arrowType.info
29-
if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16 ||
30-
infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8 ||
31-
infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32 ||
32-
infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32 {
28+
let typeId = arrowType.id
29+
switch typeId {
30+
case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64:
3331
return .success(org_apache_arrow_flatbuf_Type_.int)
34-
} else if infoType == ArrowType.ArrowFloat || infoType == ArrowType.ArrowDouble {
32+
case .float, .double:
3533
return .success(org_apache_arrow_flatbuf_Type_.floatingpoint)
36-
} else if infoType == ArrowType.ArrowString {
34+
case .string:
3735
return .success(org_apache_arrow_flatbuf_Type_.utf8)
38-
} else if infoType == ArrowType.ArrowBinary {
36+
case .binary:
3937
return .success(org_apache_arrow_flatbuf_Type_.binary)
40-
} else if infoType == ArrowType.ArrowBool {
38+
case .boolean:
4139
return .success(org_apache_arrow_flatbuf_Type_.bool)
42-
} else if infoType == ArrowType.ArrowDate32 || infoType == ArrowType.ArrowDate64 {
40+
case .date32, .date64:
4341
return .success(org_apache_arrow_flatbuf_Type_.date)
44-
} else if infoType == ArrowType.ArrowTime32 || infoType == ArrowType.ArrowTime64 {
42+
case .time32, .time64:
4543
return .success(org_apache_arrow_flatbuf_Type_.time)
44+
case .strct:
45+
return .success(org_apache_arrow_flatbuf_Type_.struct_)
46+
default:
47+
return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(typeId)"))
4648
}
47-
return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(infoType)"))
4849
}
4950

50-
func toFBType( // swiftlint:disable:this cyclomatic_complexity
51+
func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_length
5152
_ fbb: inout FlatBufferBuilder,
5253
arrowType: ArrowType
5354
) -> Result<Offset, ArrowError> {
5455
let infoType = arrowType.info
55-
if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 {
56+
switch arrowType.id {
57+
case .int8, .uint8:
5658
return .success(org_apache_arrow_flatbuf_Int.createInt(
5759
&fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8))
58-
} else if infoType == ArrowType.ArrowInt16 || infoType == ArrowType.ArrowUInt16 {
60+
case .int16, .uint16:
5961
return .success(org_apache_arrow_flatbuf_Int.createInt(
6062
&fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16))
61-
} else if infoType == ArrowType.ArrowInt32 || infoType == ArrowType.ArrowUInt32 {
63+
case .int32, .uint32:
6264
return .success(org_apache_arrow_flatbuf_Int.createInt(
6365
&fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32))
64-
} else if infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt64 {
66+
case .int64, .uint64:
6567
return .success(org_apache_arrow_flatbuf_Int.createInt(
6668
&fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64))
67-
} else if infoType == ArrowType.ArrowFloat {
69+
case .float:
6870
return .success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .single))
69-
} else if infoType == ArrowType.ArrowDouble {
71+
case .double:
7072
return .success(org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .double))
71-
} else if infoType == ArrowType.ArrowString {
73+
case .string:
7274
return .success(org_apache_arrow_flatbuf_Utf8.endUtf8(
7375
&fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb)))
74-
} else if infoType == ArrowType.ArrowBinary {
76+
case .binary:
7577
return .success(org_apache_arrow_flatbuf_Binary.endBinary(
7678
&fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb)))
77-
} else if infoType == ArrowType.ArrowBool {
79+
case .boolean:
7880
return .success(org_apache_arrow_flatbuf_Bool.endBool(
7981
&fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb)))
80-
} else if infoType == ArrowType.ArrowDate32 {
82+
case .date32:
8183
let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
8284
org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb)
8385
return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
84-
} else if infoType == ArrowType.ArrowDate64 {
86+
case .date64:
8587
let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
8688
org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb)
8789
return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
88-
} else if infoType == ArrowType.ArrowTime32 {
90+
case .time32:
8991
let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
9092
if let timeType = arrowType as? ArrowTypeTime32 {
9193
org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == .seconds ? .second : .millisecond, &fbb)
9294
return .success(org_apache_arrow_flatbuf_Time.endTime(&fbb, start: startOffset))
9395
}
9496

9597
return .failure(.invalid("Unable to case to Time32"))
96-
} else if infoType == ArrowType.ArrowTime64 {
98+
case .time64:
9799
let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
98100
if let timeType = arrowType as? ArrowTypeTime64 {
99101
org_apache_arrow_flatbuf_Time.add(unit: timeType.unit == .microseconds ? .microsecond : .nanosecond, &fbb)
100102
return .success(org_apache_arrow_flatbuf_Time.endTime(&fbb, start: startOffset))
101103
}
102104

103105
return .failure(.invalid("Unable to case to Time64"))
106+
case .strct:
107+
let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb)
108+
return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset))
109+
default:
110+
return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
104111
}
105-
106-
return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
107112
}
108113

109114
func addPadForAlignment(_ data: inout Data, alignment: Int = 8) {

swift/Arrow/Sources/Arrow/ProtoUtil.swift

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import Foundation
1919

20-
func fromProto( // swiftlint:disable:this cyclomatic_complexity
20+
func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_length
2121
field: org_apache_arrow_flatbuf_Field
2222
) -> ArrowField {
2323
let type = field.typeType
@@ -65,7 +65,13 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity
6565
arrowType = ArrowTypeTime64(arrowUnit)
6666
}
6767
case .struct_:
68-
arrowType = ArrowType(ArrowType.ArrowStruct)
68+
var children = [ArrowField]()
69+
for index in 0..<field.childrenCount {
70+
let childField = field.children(at: index)!
71+
children.append(fromProto(field: childField))
72+
}
73+
74+
arrowType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
6975
default:
7076
arrowType = ArrowType(ArrowType.ArrowUnknown)
7177
}

0 commit comments

Comments
 (0)