Commit 97cdaa5

[SPARK-51679] Support dtypes for DataFrame
### What changes were proposed in this pull request?

This PR aims to support `dtypes` for `DataFrame`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No, this is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #36 from dongjoon-hyun/SPARK-51679.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 0b087ee · commit 97cdaa5
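For reference, a minimal usage sketch of the new property (a sketch only, assuming a reachable Spark Connect server; the session and `sql` APIs are the ones already used in this repository's tests):

```swift
// Minimal sketch, assuming a running Spark Connect server.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT 1 AS id, 'swift' AS name")

// `dtypes` is an async throwing property: it triggers schema analysis on
// first access and returns one (columnName, simpleTypeString) pair per column.
let types = try await df.dtypes  // [("id", "int"), ("name", "string")]
print(types)

await spark.stop()
```

Like PySpark's `DataFrame.dtypes`, the type strings use Spark's simple-string form (`int`, `bigint`, `array<int>`, and so on).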

5 files changed: +167 -0 lines changed
Sources/SparkConnect/DataFrame.swift

Lines changed: 7 additions & 0 deletions
```diff
@@ -95,6 +95,13 @@ public actor DataFrame: Sendable {
     return try self.schema!.jsonString()
   }
 
+  var dtypes: [(String, String)] {
+    get async throws {
+      try await analyzePlanIfNeeded()
+      return try self.schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
+    }
+  }
+
   private func analyzePlanIfNeeded() async throws {
     if self.schema != nil {
       return
```
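Design note: `dtypes` is a `get async throws` accessor because computing it may require an `AnalyzePlan` round-trip to the Spark Connect server. `analyzePlanIfNeeded()` (shown above) returns early when `self.schema` is already populated, so the schema is fetched once and reused by later accesses.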

Sources/SparkConnect/Extension.swift

Lines changed: 124 additions & 0 deletions
```diff
@@ -95,3 +95,127 @@ extension SparkSession: Equatable {
     return lhs.sessionID == rhs.sessionID
   }
 }
+
+extension YearMonthInterval {
+  func fieldToString(_ field: Int32) throws -> String {
+    return switch field {
+    case 0: "year"
+    case 1: "month"
+    default:
+      throw SparkConnectError.InvalidTypeException
+    }
+  }
+
+  func toString() throws -> String {
+    let startFieldName = try fieldToString(self.startField)
+    let endFieldName = try fieldToString(self.endField)
+    let interval = if startFieldName == endFieldName {
+      "interval \(startFieldName)"
+    } else if startFieldName < endFieldName {
+      "interval \(startFieldName) to \(endFieldName)"
+    } else {
+      throw SparkConnectError.InvalidTypeException
+    }
+    return interval
+  }
+}
+
+extension DayTimeInterval {
+  func fieldToString(_ field: Int32) throws -> String {
+    return switch field {
+    case 0: "day"
+    case 1: "hour"
+    case 2: "minute"
+    case 3: "second"
+    default:
+      throw SparkConnectError.InvalidTypeException
+    }
+  }
+
+  func toString() throws -> String {
+    let startFieldName = try fieldToString(self.startField)
+    let endFieldName = try fieldToString(self.endField)
+    let interval = if startFieldName == endFieldName {
+      "interval \(startFieldName)"
+    } else if startFieldName < endFieldName {
+      "interval \(startFieldName) to \(endFieldName)"
+    } else {
+      throw SparkConnectError.InvalidTypeException
+    }
+    return interval
+  }
+}
+
+extension MapType {
+  func toString() throws -> String {
+    return "map<\(try self.keyType.simpleString),\(try self.valueType.simpleString)>"
+  }
+}
+
+extension StructType {
+  func toString() throws -> String {
+    let fieldTypes = try fields.map { "\($0.name):\(try $0.dataType.simpleString)" }
+    return "struct<\(fieldTypes.joined(separator: ","))>"
+  }
+}
+
+extension DataType {
+  var simpleString: String {
+    get throws {
+      return switch self.kind {
+      case .null:
+        "void"
+      case .binary:
+        "binary"
+      case .boolean:
+        "boolean"
+      case .byte:
+        "tinyint"
+      case .short:
+        "smallint"
+      case .integer:
+        "int"
+      case .long:
+        "bigint"
+      case .float:
+        "float"
+      case .double:
+        "double"
+      case .decimal:
+        "decimal(\(self.decimal.precision),\(self.decimal.scale))"
+      case .string:
+        "string"
+      case .char:
+        "char"
+      case .varChar:
+        "varchar"
+      case .date:
+        "date"
+      case .timestamp:
+        "timestamp"
+      case .timestampNtz:
+        "timestamp_ntz"
+      case .calendarInterval:
+        "interval"
+      case .yearMonthInterval:
+        try self.yearMonthInterval.toString()
+      case .dayTimeInterval:
+        try self.dayTimeInterval.toString()
+      case .array:
+        "array<\(try self.array.elementType.simpleString)>"
+      case .struct:
+        try self.struct.toString()
+      case .map:
+        try self.map.toString()
+      case .variant:
+        "variant"
+      case .udt:
+        self.udt.type
+      case .unparsed:
+        self.unparsed.dataTypeString
+      default:
+        throw SparkConnectError.InvalidTypeException
+      }
+    }
+  }
+}
```
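Because `simpleString` recurses through the `array`, `map`, and `struct` cases, nested types render as nested simple strings. An illustrative sketch (the column expression is chosen here to exercise the recursion; expected output follows the mappings above and the test cases below):

```swift
// Illustrative only: combines MapType.toString(), the .array case, and
// StructType.toString() from the extension above.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT map('k', array(struct(1, 'x'))) AS c")
// Expected: [("c", "map<string,array<struct<col1:int,col2:string>>>")]
print(try await df.dtypes)
await spark.stop()
```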

Sources/SparkConnect/SparkConnectError.swift

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,4 +21,5 @@
 enum SparkConnectError: Error {
   case UnsupportedOperationException
   case InvalidSessionIDException
+  case InvalidTypeException
 }
```
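A hedged sketch of how a caller might handle the new case (hypothetical; `df` stands for any `DataFrame` whose schema contains a type kind that `simpleString` does not cover):

```swift
do {
  _ = try await df.dtypes
} catch SparkConnectError.InvalidTypeException {
  print("Schema contains a type that has no simple-string rendering")
}
```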

Sources/SparkConnect/TypeAliases.swift

Lines changed: 4 additions & 0 deletions
```diff
@@ -21,12 +21,14 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
 typealias ConfigRequest = Spark_Connect_ConfigRequest
 typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
+typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias KeyValue = Spark_Connect_KeyValue
 typealias Limit = Spark_Connect_Limit
+typealias MapType = Spark_Connect_DataType.Map
 typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
@@ -35,5 +37,7 @@ typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
+typealias StructType = Spark_Connect_DataType.Struct
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
+typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
```

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 31 additions & 0 deletions
```diff
@@ -78,6 +78,37 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func dtypes() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let expected = [
+      ("null", "void"),
+      ("127Y", "tinyint"),
+      ("32767S", "smallint"),
+      ("2147483647", "int"),
+      ("9223372036854775807L", "bigint"),
+      ("1.0F", "float"),
+      ("1.0D", "double"),
+      ("1.23", "decimal(3,2)"),
+      ("binary('abc')", "binary"),
+      ("true", "boolean"),
+      ("'abc'", "string"),
+      ("INTERVAL 1 YEAR", "interval year"),
+      ("INTERVAL 1 MONTH", "interval month"),
+      ("INTERVAL 1 DAY", "interval day"),
+      ("INTERVAL 1 HOUR", "interval hour"),
+      ("INTERVAL 1 MINUTE", "interval minute"),
+      ("INTERVAL 1 SECOND", "interval second"),
+      ("array(1, 2, 3)", "array<int>"),
+      ("struct(1, 'a')", "struct<col1:int,col2:string>"),
+      ("map('language', 'Swift')", "map<string,string>"),
+    ]
+    for pair in expected {
+      #expect(try await spark.sql("SELECT \(pair.0)").dtypes[0].1 == pair.1)
+    }
+    await spark.stop()
+  }
+
   @Test
   func explain() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
```
