From 276c1ac80608db4e08351f4a2af1d0f4c7401546 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 21 Apr 2025 11:02:20 +0900 Subject: [PATCH] [SPARK-51853] Improve `show` to support all signatures --- Sources/SparkConnect/DataFrame.swift | 61 +++++++------ Sources/SparkConnect/SparkConnectClient.swift | 15 ++++ Sources/SparkConnect/TypeAliases.swift | 1 + .../DataFrameInternalTests.swift | 86 +++++++++++++++++++ Tests/SparkConnectTests/DataFrameTests.swift | 12 +++ 5 files changed, 148 insertions(+), 27 deletions(-) create mode 100644 Tests/SparkConnectTests/DataFrameInternalTests.swift diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index e0506cc..f9156a2 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -226,35 +226,42 @@ public actor DataFrame: Sendable { return result } - /// Execute the plan and show the result. + /// Displays the top 20 rows of ``DataFrame`` in a tabular form. public func show() async throws { - try await execute() + try await show(20) + } - if let schema = self._schema { - var columns: [TextTableColumn] = [] - for f in schema.struct.fields { - columns.append(TextTableColumn(header: f.name)) - } - var table = TextTable(columns: columns) - for batch in self.batches { - for i in 0.. DataFrame { + let plan = SparkConnectClient.getShowString(self.plan.root, numRows, truncate, vertical) + return DataFrame(spark: self.spark, plan: plan) } /// Projects a set of expressions and returns a new ``DataFrame``. diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 26f34e6..64f43a8 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -317,6 +317,21 @@ public actor SparkConnectClient { }) } + static func getShowString( + _ child: Relation, _ numRows: Int32, _ truncate: Int32 = 0, _ vertical: Bool = false + ) -> Plan { + var showString = ShowString() + showString.input = child + showString.numRows = numRows + showString.truncate = truncate + showString.vertical = vertical + var relation = Relation() + relation.showString = showString + var plan = Plan() + plan.opType = .root(relation) + return plan + } + func getTreeString(_ sessionID: String, _ plan: Plan, _ level: Int32) async -> AnalyzePlanRequest { return analyze( diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 2ae0636..34632ce 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -44,6 +44,7 @@ typealias Sample = Spark_Connect_Sample typealias SaveMode = Spark_Connect_WriteOperation.SaveMode typealias SetOperation = Spark_Connect_SetOperation typealias SetOpType = SetOperation.SetOpType +typealias ShowString = Spark_Connect_ShowString typealias SparkConnectService = Spark_Connect_SparkConnectService typealias Sort = Spark_Connect_Sort typealias StructType = Spark_Connect_DataType.Struct diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift new file mode 100644 index 0000000..49814aa --- /dev/null +++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift @@ -0,0 +1,86 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +import Testing + +@testable import SparkConnect + +/// A test suite for `DataFrame` internal APIs +struct DataFrameInternalTests { + +#if !os(Linux) + @Test + func showString() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + #expect( + try rows[0].get(0) as! String == """ + +---+ + |id | + +---+ + |0 | + |1 | + +---+ + only showing top 2 rows + """) + await spark.stop() + } + + @Test + func showStringTruncate() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')") + .showString(2, 2, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! String) + #expect( + try rows[0].get(0) as! String == """ + +----+----+ + |col1|col2| + +----+----+ + | ab| de| + | gh| jk| + +----+----+ + + """) + await spark.stop() + } + + @Test + func showStringVertical() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, true).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! String) + #expect( + try rows[0].get(0) as! String == """ + -RECORD 0-- + id | 0 + -RECORD 1-- + id | 1 + only showing top 2 rows + """) + await spark.stop() + } +#endif +} diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 45c4117..ba53923 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -348,6 +348,18 @@ struct DataFrameTests { try await spark.sql("SELECT * FROM VALUES (true, false)").show() try await spark.sql("SELECT * FROM VALUES (1, 2)").show() try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show() + + // Check all signatures + try await spark.range(1000).show() + try await spark.range(1000).show(1) + try await spark.range(1000).show(true) + try await spark.range(1000).show(false) + try await spark.range(1000).show(1, true) + try await spark.range(1000).show(1, false) + try await spark.range(1000).show(1, 20) + try await spark.range(1000).show(1, 20, true) + try await spark.range(1000).show(1, 20, false) + await spark.stop() }
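
For readers of this patch: the new `show` overloads added to `DataFrame.swift` (and exercised by `DataFrameTests` below) presumably all funnel into the `Int32`-based variant, which collects the one-row result of `showString` and prints its single string cell. The sketch below illustrates that wiring only; it is not the verbatim patch. The 20-row and 20-character defaults and the `Bool`-to-`Int32` truncate mapping are assumptions carried over from Spark's Scala `Dataset.show`, and the extension is assumed to sit inside the `SparkConnect` module since `showString` is internal (the tests use `@testable import`).

```swift
// Illustrative sketch only — the patch defines these directly inside the
// `DataFrame` actor; defaults and bodies here are inferred, not copied.
extension DataFrame {
  /// Displays the top `numRows` rows (assumed default: truncate cells to 20 characters).
  public func show(_ numRows: Int32) async throws {
    try await show(numRows, 20)
  }

  /// Displays the top 20 rows, optionally disabling cell truncation.
  public func show(_ truncate: Bool) async throws {
    try await show(20, truncate)
  }

  /// `truncate == true` presumably maps to a 20-character limit, `false` to 0 (no limit).
  public func show(_ numRows: Int32, _ truncate: Bool) async throws {
    try await show(numRows, truncate ? 20 : 0)
  }

  /// All overloads presumably end up here: `showString` returns a one-row
  /// DataFrame whose single string column holds the rendered table
  /// (see DataFrameInternalTests), and `show` simply prints that cell.
  public func show(_ numRows: Int32, _ truncate: Int32, _ vertical: Bool = false) async throws {
    let rows = try await showString(numRows, truncate, vertical).collect()
    print(try rows[0].get(0) as! String)
  }
}
```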
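
A small usage sketch of the expanded surface, mirroring the cases added to `DataFrameTests` above (like the tests, it assumes a reachable Spark Connect server):

```swift
import SparkConnect

func demoShow() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  try await spark.range(1000).show()             // top 20 rows (default)
  try await spark.range(1000).show(3)            // top 3 rows
  try await spark.range(1000).show(false)        // default row count, no cell truncation
  try await spark.range(1000).show(3, true)      // 3 rows, truncated cells
  try await spark.range(1000).show(3, 20)        // 3 rows, cells cut at 20 characters
  try await spark.range(1000).show(3, 20, true)  // vertical, RECORD-style layout

  await spark.stop()
}
```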