Skip to content

Commit cfa33b8

Browse files
committed
[SPARK-52301] Support Decimal type
### What changes were proposed in this pull request? This PR aims to support the `Decimal` type in `Row` and `DataFrame.collect`. ### Why are the changes needed? Previously, `Decimal` was supported only in `Spark Connect Server`-side operations, e.g. `DataFrame.show`. ### Does this PR introduce _any_ user-facing change? No, this is an additional type. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #174 from dongjoon-hyun/SPARK-52301. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 8600b46 commit cfa33b8

14 files changed

+162
-8
lines changed

Sources/SparkConnect/ArrowArray.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder {
101101
return try ArrowArrayHolderImpl(FixedArray<Double>(with))
102102
case .float:
103103
return try ArrowArrayHolderImpl(FixedArray<Float>(with))
104+
case .decimal128:
105+
return try ArrowArrayHolderImpl(FixedArray<Decimal>(with))
104106
case .date32:
105107
return try ArrowArrayHolderImpl(Date32Array(with))
106108
case .date64:
@@ -247,6 +249,25 @@ public class Time32Array: FixedArray<Time32> {}
247249
/// @nodoc
248250
public class Time64Array: FixedArray<Time64> {}
249251

252+
/// @nodoc
/// An Arrow `decimal128` column whose elements are surfaced as Foundation `Decimal` values.
public class Decimal128Array: FixedArray<Decimal> {
  /// Returns the decimal stored at `index`, or `nil` when that slot is null.
  ///
  /// Arrow encodes each decimal128 element as a 16-byte little-endian two's-complement
  /// integer holding the unscaled value. Both 64-bit halves are read, with the high half
  /// interpreted as signed, so negative values and magnitudes wider than 64 bits decode
  /// correctly. The unscaled value is then divided by `10^scale`.
  public override subscript(_ index: UInt) -> Decimal? {
    if self.arrowData.isNull(index) {
      return nil
    }
    // Take the scale from the type id; any other id falls back to 18.
    let scale: Int32 = switch self.arrowData.type.id {
    case .decimal128(_, let scale):
      scale
    default:
      18
    }
    let byteOffset = self.arrowData.stride * Int(index)
    let element = self.arrowData.buffers[1].rawPointer.advanced(by: byteOffset)
    // The low half is an unsigned magnitude; the high half carries the sign.
    let low = UInt64(littleEndian: element.load(as: UInt64.self))
    let high = Int64(littleEndian: element.load(fromByteOffset: 8, as: Int64.self))
    // unscaled = high * 2^64 + low, evaluated in Decimal arithmetic.
    let unscaled = Decimal(high) * pow(Decimal(2), 64) + Decimal(low)
    return unscaled / pow(10, Int(scale))
  }
}
270+
250271
/// @nodoc
251272
public class BinaryArray: ArrowArray<Data> {
252273
public struct Options {

Sources/SparkConnect/ArrowArrayBuilder.swift

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ public class Time64ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Time64>, T
122122
}
123123
}
124124

125+
/// Array builder that collects `Decimal` values into a `Decimal128Array`.
public class Decimal128ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Decimal>, Decimal128Array> {
  /// Creates a builder whose Arrow type is decimal128 with the given precision and scale.
  fileprivate convenience init(precision: Int32, scale: Int32) throws {
    let decimalType = ArrowTypeDecimal128(precision: precision, scale: scale)
    try self.init(decimalType)
  }
}
130+
125131
public class StructArrayBuilder: ArrowArrayBuilder<StructBufferBuilder, StructArray> {
126132
let builders: [any ArrowArrayHolderBuilder]
127133
let fields: [ArrowField]
@@ -202,6 +208,8 @@ public class ArrowArrayBuilders {
202208
return try ArrowArrayBuilders.loadBoolArrayBuilder()
203209
} else if builderType == Date.self || builderType == Date?.self {
204210
return try ArrowArrayBuilders.loadDate64ArrayBuilder()
211+
} else if builderType == Decimal.self || builderType == Decimal?.self {
212+
return try ArrowArrayBuilders.loadDecimal128ArrayBuilder(38, 18)
205213
} else {
206214
throw ArrowError.invalid("Invalid type for builder: \(builderType)")
207215
}
@@ -214,7 +222,7 @@ public class ArrowArrayBuilders {
214222
|| type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self
215223
|| type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self
216224
|| type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self
217-
|| type == Float.self || type == Date.self
225+
|| type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self
218226
}
219227

220228
public static func loadStructArrayBuilderForType<T>(_ obj: T) throws -> StructArrayBuilder {
@@ -279,6 +287,11 @@ public class ArrowArrayBuilders {
279287
throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not found")
280288
}
281289
return try Time64ArrayBuilder(timeType.unit)
290+
case .decimal128:
291+
guard let decimalType = arrowType as? ArrowTypeDecimal128 else {
292+
throw ArrowError.invalid("Expected ArrowTypeDecimal128 for decimal128 type")
293+
}
294+
return try Decimal128ArrayBuilder(precision: decimalType.precision, scale: decimalType.scale)
282295
default:
283296
throw ArrowError.unknownType("Builder not found for arrow type: \(arrowType.id)")
284297
}
@@ -306,6 +319,8 @@ public class ArrowArrayBuilders {
306319
return try NumberArrayBuilder<T>()
307320
} else if type == Double.self {
308321
return try NumberArrayBuilder<T>()
322+
} else if type == Decimal.self {
323+
return try NumberArrayBuilder<T>()
309324
} else {
310325
throw ArrowError.unknownType("Type is invalid for NumberArrayBuilder")
311326
}
@@ -338,4 +353,11 @@ public class ArrowArrayBuilders {
338353
/// Creates a `Time64ArrayBuilder` for the given time unit.
public static func loadTime64ArrayBuilder(_ unit: ArrowTime64Unit) throws -> Time64ArrayBuilder {
  let builder = try Time64ArrayBuilder(unit)
  return builder
}
356+
357+
/// Creates a `Decimal128ArrayBuilder` for the given precision and scale,
/// which default to 38 and 18 respectively.
public static func loadDecimal128ArrayBuilder(
  _ precision: Int32 = 38,
  _ scale: Int32 = 18
) throws -> Decimal128ArrayBuilder {
  let builder = try Decimal128ArrayBuilder(precision: precision, scale: scale)
  return builder
}
341363
}

Sources/SparkConnect/ArrowBufferBuilder.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ public class FixedBufferBuilder<T>: ValuesBufferBuilder<T>, ArrowBufferBuilder {
142142
return Float(0) as! T // swiftlint:disable:this force_cast
143143
} else if type == Double.self {
144144
return Double(0) as! T // swiftlint:disable:this force_cast
145+
} else if type == Decimal.self {
146+
return Decimal(0) as! T // swiftlint:disable:this force_cast
145147
}
146148

147149
throw ArrowError.unknownType("Unable to determine default value")

Sources/SparkConnect/ArrowDecoder.swift

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
160160
|| type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self
161161
|| type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self
162162
|| type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self
163-
|| type == Float.self || type == Date.self
163+
|| type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self
164164
{
165165
defer { increment() }
166166
return try self.decoder.doDecode(self.currentIndex)!
@@ -260,8 +260,12 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
260260
return try self.decoder.doDecode(key)!
261261
}
262262

263+
/// Decodes the `Decimal` stored under `key`.
///
/// The decoder's optional result is force-unwrapped, so a `nil` result traps.
func decode(_ type: Decimal.Type, forKey key: Key) throws -> Decimal {
  let decoded: Decimal? = try decoder.doDecode(key)
  return decoded!
}
266+
263267
func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable {
264-
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
268+
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self {
265269
return try self.decoder.doDecode(key)!
266270
} else {
267271
throw ArrowError.invalid("Type \(type) is currently not supported")
@@ -363,8 +367,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
363367
return try self.decoder.doDecode(self.decoder.singleRBCol)!
364368
}
365369

370+
/// Decodes the decoder's single-value column as a `Decimal`.
///
/// The decoder's optional result is force-unwrapped, so a `nil` result traps.
func decode(_ type: Decimal.Type) throws -> Decimal {
  let decoded: Decimal? = try decoder.doDecode(decoder.singleRBCol)
  return decoded!
}
373+
366374
func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
367-
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
375+
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self {
368376
return try self.decoder.doDecode(self.decoder.singleRBCol)!
369377
} else {
370378
throw ArrowError.invalid("Type \(type) is currently not supported")

Sources/SparkConnect/ArrowReaderHelper.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@ private func makeStringHolder(
4949
}
5050
}
5151

52+
/// Builds an array holder for a decimal column.
///
/// Wraps `buffers` in an `ArrowData` typed as `field.type` and exposes it through a
/// `Decimal128Array`. An `ArrowError` thrown along the way is returned unchanged; any
/// other error is reported as `.unknownError` carrying its description.
private func makeDecimalHolder(
  _ field: ArrowField,
  buffers: [ArrowBuffer],
  nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
  let outcome = Result<ArrowArrayHolder, Error> {
    let data = try ArrowData(field.type, buffers: buffers, nullCount: nullCount)
    return ArrowArrayHolderImpl(try Decimal128Array(data))
  }
  return outcome.mapError { failure in
    (failure as? ArrowError) ?? .unknownError("\(failure)")
  }
}
66+
5267
private func makeDateHolder(
5368
_ field: ArrowField,
5469
buffers: [ArrowBuffer],
@@ -183,6 +198,8 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
183198
return makeFixedHolder(Int64.self, field: field, buffers: buffers, nullCount: nullCount)
184199
case .uint64:
185200
return makeFixedHolder(UInt64.self, field: field, buffers: buffers, nullCount: nullCount)
201+
case .decimal128:
202+
return makeDecimalHolder(field, buffers: buffers, nullCount: nullCount)
186203
case .boolean:
187204
return makeBoolHolder(buffers, nullCount: nullCount)
188205
case .float:
@@ -217,7 +234,7 @@ func makeBuffer(
217234

218235
func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
219236
switch type {
220-
case .int, .bool, .floatingpoint, .date, .time:
237+
case .int, .bool, .floatingpoint, .date, .time, .decimal:
221238
return true
222239
default:
223240
return false
@@ -266,6 +283,12 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_bo
266283
default:
267284
return ArrowType(ArrowType.ArrowUnknown)
268285
}
286+
case .decimal:
287+
let dataType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)!
288+
if dataType.bitWidth == 128 {
289+
return ArrowType(ArrowType.ArrowDecimal128)
290+
}
291+
return ArrowType(ArrowType.ArrowUnknown)
269292
case .utf8:
270293
return ArrowType(ArrowType.ArrowString)
271294
case .binary:

Sources/SparkConnect/ArrowType.swift

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ public enum ArrowError: Error {
4242
case invalid(String)
4343
}
4444

45-
public enum ArrowTypeId: Sendable {
45+
public enum ArrowTypeId: Sendable, Equatable {
4646
case binary
4747
case boolean
4848
case date32
4949
case date64
5050
case dateType
51-
case decimal128
51+
case decimal128(_ precision: Int32, _ scale: Int32)
5252
case decimal256
5353
case dictionary
5454
case double
@@ -129,6 +129,23 @@ public class ArrowTypeTime64: ArrowType {
129129
}
130130
}
131131

132+
/// Arrow `decimal128` type that records the column's precision and scale.
public class ArrowTypeDecimal128: ArrowType {
  let precision: Int32
  let scale: Int32

  /// Creates a decimal128 type.
  /// - Parameters:
  ///   - precision: Total number of decimal digits.
  ///   - scale: Number of digits to the right of the decimal point.
  public init(precision: Int32, scale: Int32) {
    self.precision = precision
    self.scale = scale
    // Put the actual precision/scale into the type id instead of the shared
    // `ArrowDecimal128` info, which is fixed at (38, 18). This keeps `id`
    // consistent with the stored properties and with `cDataFormatId`.
    super.init(ArrowType.Info.primitiveInfo(ArrowTypeId.decimal128(precision, scale)))
  }

  /// The C data interface format string, `d:<precision>,<scale>`.
  public override var cDataFormatId: String {
    get throws {
      return "d:\(precision),\(scale)"
    }
  }
}
148+
132149
/// @nodoc
133150
public class ArrowNestedType: ArrowType {
134151
let fields: [ArrowField]
@@ -156,6 +173,7 @@ public class ArrowType {
156173
public static let ArrowBool = Info.primitiveInfo(ArrowTypeId.boolean)
157174
public static let ArrowDate32 = Info.primitiveInfo(ArrowTypeId.date32)
158175
public static let ArrowDate64 = Info.primitiveInfo(ArrowTypeId.date64)
176+
public static let ArrowDecimal128 = Info.primitiveInfo(ArrowTypeId.decimal128(38, 18))
159177
public static let ArrowBinary = Info.variableInfo(ArrowTypeId.binary)
160178
public static let ArrowTime32 = Info.timeInfo(ArrowTypeId.time32)
161179
public static let ArrowTime64 = Info.timeInfo(ArrowTypeId.time64)
@@ -216,6 +234,8 @@ public class ArrowType {
216234
return ArrowType.ArrowFloat
217235
} else if type == Double.self {
218236
return ArrowType.ArrowDouble
237+
} else if type == Decimal.self {
238+
return ArrowType.ArrowDecimal128
219239
} else {
220240
return ArrowType.ArrowUnknown
221241
}
@@ -242,6 +262,8 @@ public class ArrowType {
242262
return ArrowType.ArrowFloat
243263
} else if type == Double.self {
244264
return ArrowType.ArrowDouble
265+
} else if type == Decimal.self {
266+
return ArrowType.ArrowDecimal128
245267
} else {
246268
return ArrowType.ArrowUnknown
247269
}
@@ -271,6 +293,8 @@ public class ArrowType {
271293
return MemoryLayout<Float>.stride
272294
case .double:
273295
return MemoryLayout<Double>.stride
296+
case .decimal128:
297+
return 16 // Decimal 128 (= 16 * 8) bits
274298
case .boolean:
275299
return MemoryLayout<Bool>.stride
276300
case .date32:
@@ -315,6 +339,8 @@ public class ArrowType {
315339
return "f"
316340
case ArrowTypeId.double:
317341
return "g"
342+
case ArrowTypeId.decimal128(let precision, let scale):
343+
return "d:\(precision),\(scale)"
318344
case ArrowTypeId.boolean:
319345
return "b"
320346
case ArrowTypeId.date32:
@@ -344,6 +370,7 @@ public class ArrowType {
344370
public static func fromCDataFormatId( // swiftlint:disable:this cyclomatic_complexity
345371
_ from: String
346372
) throws -> ArrowType {
373+
let REGEX_DECIMAL_TYPE = /^d:(\d+),(\d+)$/
347374
if from == "c" {
348375
return ArrowType(ArrowType.ArrowInt8)
349376
} else if from == "s" {
@@ -364,6 +391,10 @@ public class ArrowType {
364391
return ArrowType(ArrowType.ArrowFloat)
365392
} else if from == "g" {
366393
return ArrowType(ArrowType.ArrowDouble)
394+
} else if from.contains(REGEX_DECIMAL_TYPE) {
395+
let match = from.firstMatch(of: REGEX_DECIMAL_TYPE)!
396+
let decimalType = ArrowTypeId.decimal128(Int32(match.1)!, Int32(match.2)!)
397+
return ArrowType(Info.primitiveInfo(decimalType))
367398
} else if from == "b" {
368399
return ArrowType(ArrowType.ArrowBool)
369400
} else if from == "tdD" {

Sources/SparkConnect/DataFrame.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ public actor DataFrame: Sendable {
408408
values.append(array.asAny(i) as? Float)
409409
case .primitiveInfo(.double):
410410
values.append(array.asAny(i) as? Double)
411+
case .primitiveInfo(.decimal128):
412+
values.append(array.asAny(i) as? Decimal)
411413
case .primitiveInfo(.date32):
412414
values.append(array.asAny(i) as! Date)
413415
case ArrowType.ArrowBinary:

Sources/SparkConnect/ProtoUtil.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity
4444
} else if floatType.precision == .double {
4545
arrowType = ArrowType(ArrowType.ArrowDouble)
4646
}
47+
case .decimal:
48+
let decimalType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)!
49+
if decimalType.bitWidth == 128 && decimalType.precision <= 38 {
50+
let arrowDecimal128 = ArrowTypeId.decimal128(decimalType.precision, decimalType.scale)
51+
arrowType = ArrowType(ArrowType.Info.primitiveInfo(arrowDecimal128))
52+
} else {
53+
// Not supported yet
54+
arrowType = ArrowType(ArrowType.ArrowUnknown)
55+
}
4756
case .utf8:
4857
arrowType = ArrowType(ArrowType.ArrowString)
4958
case .binary:

Sources/SparkConnect/Row.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ public struct Row: Sendable, Equatable {
6969
return a == b
7070
} else if let a = x as? Double, let b = y as? Double {
7171
return a == b
72+
} else if let a = x as? Decimal, let b = y as? Decimal {
73+
return a == b
7274
} else if let a = x as? String, let b = y as? String {
7375
return a == b
7476
} else {

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,25 @@ struct DataFrameTests {
893893
}
894894
await spark.stop()
895895
}
896+
897+
/// Collects decimal columns, including NULL cells, and compares the reported
/// column types and row values against the expected ones.
@Test
func decimal() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let query = """
    SELECT * FROM VALUES
    (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)),
    (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL))
    """
  let df = try await spark.sql(query)

  let expectedTypes = ["decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)"]
  #expect(try await df.dtypes.map { $0.1 } == expectedTypes)

  let expectedRows = [
    Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)),
    Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil)
  ]
  #expect(try await df.collect() == expectedRows)
  await spark.stop()
}
896915
#endif
897916

898917
@Test

0 commit comments

Comments
 (0)