From 32b5e857898b20b28f551692526d826bbfdaf5d2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 Apr 2025 18:36:01 +0900 Subject: [PATCH] [SPARK-51679] Support `dtypes` for `DataFrame` --- Sources/SparkConnect/DataFrame.swift | 7 ++ Sources/SparkConnect/Extension.swift | 124 +++++++++++++++++++ Sources/SparkConnect/SparkConnectError.swift | 1 + Sources/SparkConnect/TypeAliases.swift | 4 + Tests/SparkConnectTests/DataFrameTests.swift | 31 +++++ 5 files changed, 167 insertions(+) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 9e21ba6..2588ef7 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -95,6 +95,13 @@ public actor DataFrame: Sendable { return try self.schema!.jsonString() } + var dtypes: [(String, String)] { + get async throws { + try await analyzePlanIfNeeded() + return try self.schema!.struct.fields.map { ($0.name, try $0.dataType.simpleString) } + } + } + private func analyzePlanIfNeeded() async throws { if self.schema != nil { return diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 1d470fe..7fdbaee 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -95,3 +95,127 @@ extension SparkSession: Equatable { return lhs.sessionID == rhs.sessionID } } + +extension YearMonthInterval { + func fieldToString(_ field: Int32) throws -> String { + return switch field { + case 0: "year" + case 1: "month" + default: + throw SparkConnectError.InvalidTypeException + } + } + + func toString() throws -> String { + let startFieldName = try fieldToString(self.startField) + let endFieldName = try fieldToString(self.endField) + let interval = if startFieldName == endFieldName { + "interval \(startFieldName)" + } else if self.startField < self.endField { // compare numeric field codes; comparing names is lexicographic ("year" < "month" is false) and would wrongly reject "interval year to month" + "interval \(startFieldName) to \(endFieldName)" + } else { + throw SparkConnectError.InvalidTypeException + } + return interval + } +} + 
+extension DayTimeInterval { + func fieldToString(_ field: Int32) throws -> String { + return switch field { + case 0: "day" + case 1: "hour" + case 2: "minute" + case 3: "second" + default: + throw SparkConnectError.InvalidTypeException + } + } + + func toString() throws -> String { + let startFieldName = try fieldToString(self.startField) + let endFieldName = try fieldToString(self.endField) + let interval = if startFieldName == endFieldName { + "interval \(startFieldName)" + } else if self.startField < self.endField { // compare numeric field codes (0..3); the names "day"<"hour"<"minute"<"second" are only accidentally in alphabetical order + "interval \(startFieldName) to \(endFieldName)" + } else { + throw SparkConnectError.InvalidTypeException + } + return interval + } +} + +extension MapType { + func toString() throws -> String { + return "map<\(try self.keyType.simpleString),\(try self.valueType.simpleString)>" + } +} + +extension StructType { + func toString() throws -> String { + let fieldTypes = try fields.map { "\($0.name):\(try $0.dataType.simpleString)" } + return "struct<\(fieldTypes.joined(separator: ","))>" + } +} + +extension DataType { + var simpleString: String { + get throws { + return switch self.kind { + case .null: + "void" + case .binary: + "binary" + case .boolean: + "boolean" + case .byte: + "tinyint" + case .short: + "smallint" + case .integer: + "int" + case .long: + "bigint" + case .float: + "float" + case .double: + "double" + case .decimal: + "decimal(\(self.decimal.precision),\(self.decimal.scale))" + case .string: + "string" + case .char: + "char" + case .varChar: + "varchar" + case .date: + "date" + case .timestamp: + "timestamp" + case .timestampNtz: + "timestamp_ntz" + case .calendarInterval: + "interval" + case .yearMonthInterval: + try self.yearMonthInterval.toString() + case .dayTimeInterval: + try self.dayTimeInterval.toString() + case .array: + "array<\(try self.array.elementType.simpleString)>" + case .struct: + try self.struct.toString() + case .map: + try self.map.toString() + case .variant: + "variant" + case .udt: + self.udt.type + case 
.unparsed: + self.unparsed.dataTypeString + default: + throw SparkConnectError.InvalidTypeException + } + } + } +} diff --git a/Sources/SparkConnect/SparkConnectError.swift b/Sources/SparkConnect/SparkConnectError.swift index e88c061..97407f4 100644 --- a/Sources/SparkConnect/SparkConnectError.swift +++ b/Sources/SparkConnect/SparkConnectError.swift @@ -21,4 +21,5 @@ enum SparkConnectError: Error { case UnsupportedOperationException case InvalidSessionIDException + case InvalidTypeException } diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index aa1e087..ff8fd11 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -21,12 +21,14 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse typealias ConfigRequest = Spark_Connect_ConfigRequest typealias DataSource = Spark_Connect_Read.DataSource typealias DataType = Spark_Connect_DataType +typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode typealias ExpressionString = Spark_Connect_Expression.ExpressionString typealias Filter = Spark_Connect_Filter typealias KeyValue = Spark_Connect_KeyValue typealias Limit = Spark_Connect_Limit +typealias MapType = Spark_Connect_DataType.Map typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze typealias Plan = Spark_Connect_Plan typealias Project = Spark_Connect_Project @@ -35,5 +37,7 @@ typealias Read = Spark_Connect_Read typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService typealias Sort = Spark_Connect_Sort +typealias StructType = Spark_Connect_DataType.Struct typealias UserContext = Spark_Connect_UserContext typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute +typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval diff --git 
a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index ec15e43..f9dd37e 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -78,6 +78,37 @@ struct DataFrameTests { await spark.stop() } + @Test + func dtypes() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = [ + ("null", "void"), + ("127Y", "tinyint"), + ("32767S", "smallint"), + ("2147483647", "int"), + ("9223372036854775807L", "bigint"), + ("1.0F", "float"), + ("1.0D", "double"), + ("1.23", "decimal(3,2)"), + ("binary('abc')", "binary"), + ("true", "boolean"), + ("'abc'", "string"), + ("INTERVAL 1 YEAR", "interval year"), + ("INTERVAL 1 MONTH", "interval month"), + ("INTERVAL 1 DAY", "interval day"), + ("INTERVAL 1 HOUR", "interval hour"), + ("INTERVAL 1 MINUTE", "interval minute"), + ("INTERVAL 1 SECOND", "interval second"), + ("array(1, 2, 3)", "array"), + ("struct(1, 'a')", "struct"), + ("map('language', 'Swift')", "map"), + ] + for pair in expected { + #expect(try await spark.sql("SELECT \(pair.0)").dtypes[0].1 == pair.1) + } + await spark.stop() + } + @Test func explain() async throws { let spark = try await SparkSession.builder.getOrCreate()