7 changes: 7 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -95,6 +95,13 @@ public actor DataFrame: Sendable {
return try self.schema!.jsonString()
}

/// Returns all column names and their data types as an array of `(name, type)` string pairs.
var dtypes: [(String, String)] {
get async throws {
try await analyzePlanIfNeeded()
return try self.schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) }
}
}

private func analyzePlanIfNeeded() async throws {
if self.schema != nil {
return
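For context, a minimal usage sketch of the new dtypes property (hypothetical driver code, not part of this diff; it assumes a reachable Spark Connect server and mirrors the test added below):

import SparkConnect

// Sketch only: run inside an async context, with a Spark Connect server
// reachable through the default SparkSession builder configuration.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT 1 AS id, 'Swift' AS name")
// dtypes pairs each column name with its simple type string,
// e.g. [("id", "int"), ("name", "string")].
for (name, type) in try await df.dtypes {
  print("\(name): \(type)")
}
await spark.stop()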
124 changes: 124 additions & 0 deletions Sources/SparkConnect/Extension.swift
@@ -95,3 +95,127 @@ extension SparkSession: Equatable {
return lhs.sessionID == rhs.sessionID
}
}

extension YearMonthInterval {
func fieldToString(_ field: Int32) throws -> String {
return switch field {
case 0: "year"
case 1: "month"
default:
throw SparkConnectError.InvalidTypeException
}
}

func toString() throws -> String {
let startFieldName = try fieldToString(self.startField)
let endFieldName = try fieldToString(self.endField)
// Compare the field ordinals, not the names: "year" sorts after "month"
// lexicographically, which would wrongly reject "interval year to month".
let interval = if startFieldName == endFieldName {
"interval \(startFieldName)"
} else if self.startField < self.endField {
"interval \(startFieldName) to \(endFieldName)"
} else {
throw SparkConnectError.InvalidTypeException
}
return interval
}
}

extension DayTimeInterval {
func fieldToString(_ field: Int32) throws -> String {
return switch field {
case 0: "day"
case 1: "hour"
case 2: "minute"
case 3: "second"
default:
throw SparkConnectError.InvalidTypeException
}
}

func toString() throws -> String {
let startFieldName = try fieldToString(self.startField)
let endFieldName = try fieldToString(self.endField)
let interval = if startFieldName == endFieldName {
"interval \(startFieldName)"
// Same ordinal comparison as YearMonthInterval: day(0) < hour(1) < minute(2) < second(3).
} else if self.startField < self.endField {
"interval \(startFieldName) to \(endFieldName)"
} else {
throw SparkConnectError.InvalidTypeException
}
return interval
}
}
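To illustrate the ordering rule both toString() implementations enforce, here is a hedged sketch using standard Spark SQL multi-unit interval literals (assuming a live session named spark, as in the tests below):

// Sketch only: multi-unit interval literals exercise the start/end field
// check above. startField must not exceed endField, so "year to month" and
// "day to second" render, while a reversed range throws InvalidTypeException.
let df = try await spark.sql(
  "SELECT INTERVAL '1-2' YEAR TO MONTH AS ym, INTERVAL '1 2:3:4' DAY TO SECOND AS ds")
// Expected: [("ym", "interval year to month"), ("ds", "interval day to second")]
print(try await df.dtypes)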

extension MapType {
func toString() throws -> String {
return "map<\(try self.keyType.simpleString),\(try self.valueType.simpleString)>"
}
}

extension StructType {
func toString() throws -> String {
let fieldTypes = try fields.map { "\($0.name):\(try $0.dataType.simpleString)" }
return "struct<\(fieldTypes.joined(separator: ","))>"
}
}

extension DataType {
var simpleString: String {
get throws {
return switch self.kind {
case .null:
"void"
case .binary:
"binary"
case .boolean:
"boolean"
case .byte:
"tinyint"
case .short:
"smallint"
case .integer:
"int"
case .long:
"bigint"
case .float:
"float"
case .double:
"double"
case .decimal:
"decimal(\(self.decimal.precision),\(self.decimal.scale))"
case .string:
"string"
case .char:
"char(\(self.char.length))"
case .varChar:
"varchar(\(self.varChar.length))"
case .date:
"date"
case .timestamp:
"timestamp"
case .timestampNtz:
"timestamp_ntz"
case .calendarInterval:
"interval"
case .yearMonthInterval:
try self.yearMonthInterval.toString()
case .dayTimeInterval:
try self.dayTimeInterval.toString()
case .array:
"array<\(try self.array.elementType.simpleString)>"
case .struct:
try self.struct.toString()
case .map:
try self.map.toString()
case .variant:
"variant"
case .udt:
self.udt.type
case .unparsed:
self.unparsed.dataTypeString
default:
throw SparkConnectError.InvalidTypeException
}
}
}
}
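Because simpleString recurses through element, key, value, and field types, nested containers render as one composite string. A sketch of the expected output (same assumed session; the alias a inside struct() is illustrative):

// Sketch only: nested containers compose recursively via simpleString.
let df = try await spark.sql("SELECT map('k', array(struct(1 AS a))) AS c")
// Expected: [("c", "map<string,array<struct<a:int>>>")]
print(try await df.dtypes)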
1 change: 1 addition & 0 deletions Sources/SparkConnect/SparkConnectError.swift
@@ -21,4 +21,5 @@
enum SparkConnectError: Error {
case UnsupportedOperationException
case InvalidSessionIDException
case InvalidTypeException
}
4 changes: 4 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -21,12 +21,14 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataSource = Spark_Connect_Read.DataSource
typealias DataType = Spark_Connect_DataType
typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
typealias KeyValue = Spark_Connect_KeyValue
typealias Limit = Spark_Connect_Limit
typealias MapType = Spark_Connect_DataType.Map
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
@@ -35,5 +37,7 @@ typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StructType = Spark_Connect_DataType.Struct
typealias UserContext = Spark_Connect_UserContext
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
31 changes: 31 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -78,6 +78,37 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func dtypes() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let expected = [
("null", "void"),
("127Y", "tinyint"),
("32767S", "smallint"),
("2147483647", "int"),
("9223372036854775807L", "bigint"),
("1.0F", "float"),
("1.0D", "double"),
("1.23", "decimal(3,2)"),
("binary('abc')", "binary"),
("true", "boolean"),
("'abc'", "string"),
("INTERVAL 1 YEAR", "interval year"),
("INTERVAL 1 MONTH", "interval month"),
("INTERVAL 1 DAY", "interval day"),
("INTERVAL 1 HOUR", "interval hour"),
("INTERVAL 1 MINUTE", "interval minute"),
("INTERVAL 1 SECOND", "interval second"),
("array(1, 2, 3)", "array<int>"),
("struct(1, 'a')", "struct<col1:int,col2:string>"),
("map('language', 'Swift')", "map<string,string>"),
]
for pair in expected {
#expect(try await spark.sql("SELECT \(pair.0)").dtypes[0].1 == pair.1)
}
await spark.stop()
}

@Test
func explain() async throws {
let spark = try await SparkSession.builder.getOrCreate()