diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index b1ec758..9e21ba6 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -355,4 +355,24 @@ public actor DataFrame: Sendable { print(response.explain.explainString) } } + + /// Prints the schema to the console in a nice tree format. + public func printSchema() async throws { + try await printSchema(Int32.max) + } + + /// Prints the schema up to the given level to the console in a nice tree format. + /// - Parameter level: A level to be printed. + public func printSchema(_ level: Int32) async throws { + try await withGRPCClient( + transport: .http2NIOPosix( + target: .dns(host: spark.client.host, port: spark.client.port), + transportSecurity: .plaintext + ) + ) { client in + let service = Spark_Connect_SparkConnectService.Client(wrapping: client) + let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level)) + print(response.treeString.treeString) + } + } } diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 4d28c34..99e0c11 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -294,6 +294,18 @@ public actor SparkConnectClient { }) } + func getTreeString(_ sessionID: String, _ plan: Plan, _ level: Int32) async -> AnalyzePlanRequest + { + return analyze( + sessionID, + { + var treeString = AnalyzePlanRequest.TreeString() + treeString.plan = plan + treeString.level = level + return OneOf_Analyze.treeString(treeString) + }) + } + static func getProject(_ child: Relation, _ cols: [String]) -> Plan { var project = Project() project.input = child diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 87b8fa4..ec15e43 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -70,6 +70,14 @@ struct DataFrameTests { await spark.stop() } + @Test + func printSchema() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql("SELECT struct(1, 2)").printSchema() + try await spark.sql("SELECT struct(1, 2)").printSchema(1) + await spark.stop() + } + @Test func explain() async throws { let spark = try await SparkSession.builder.getOrCreate()