Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -226,35 +226,42 @@ public actor DataFrame: Sendable {
return result
}

/// Execute the plan and show the result.
/// Displays the top 20 rows of ``DataFrame`` in a tabular form.
public func show() async throws {
try await execute()
try await show(20)
}

if let schema = self._schema {
var columns: [TextTableColumn] = []
for f in schema.struct.fields {
columns.append(TextTableColumn(header: f.name))
}
var table = TextTable(columns: columns)
for batch in self.batches {
for i in 0..<batch.length {
var values: [String] = []
for column in batch.columns {
let str = column.array as! AsString
if column.data.isNull(i) {
values.append("NULL")
} else if column.data.type.info == ArrowType.ArrowBinary {
let binary = str.asString(i).utf8.map { String(format: "%02x", $0) }.joined(separator: " ")
values.append("[\(binary)]")
} else {
values.append(str.asString(i))
}
}
table.addRow(values: values)
}
}
print(table.render())
}
/// Displays the top 20 rows of ``DataFrame`` in a tabular form.
/// - Parameter truncate: Whether truncate long strings. If true, strings more than 20 characters will be truncated
/// and all cells will be aligned right
public func show(_ truncate: Bool) async throws {
try await show(20, truncate)
}

/// Displays the ``DataFrame`` in a tabular form.
/// - Parameters:
/// - numRows: Number of rows to show
/// - truncate: Whether truncate long strings. If true, strings more than 20 characters will be truncated
/// and all cells will be aligned right
public func show(_ numRows: Int32 = 20, _ truncate: Bool = true) async throws {
try await show(numRows, truncate ? 20 : 0)
}

/// Displays the ``DataFrame`` in a tabular form.
/// - Parameters:
/// - numRows: Number of rows to show
/// - truncate: If set to more than 0, truncates strings to `truncate` characters and all cells will be aligned right.
/// - vertical: If set to true, prints output rows vertically (one line per column value).
public func show(_ numRows: Int32, _ truncate: Int32, _ vertical: Bool = false) async throws {
let rows = try await showString(numRows, truncate, vertical).collect()
assert(rows.count == 1)
assert(rows[0].length == 1)
print(try rows[0].get(0) as! String)
}

func showString(_ numRows: Int32, _ truncate: Int32, _ vertical: Bool) -> DataFrame {
let plan = SparkConnectClient.getShowString(self.plan.root, numRows, truncate, vertical)
return DataFrame(spark: self.spark, plan: plan)
}

/// Projects a set of expressions and returns a new ``DataFrame``.
Expand Down
15 changes: 15 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,21 @@ public actor SparkConnectClient {
})
}

static func getShowString(
_ child: Relation, _ numRows: Int32, _ truncate: Int32 = 0, _ vertical: Bool = false
) -> Plan {
var showString = ShowString()
showString.input = child
showString.numRows = numRows
showString.truncate = truncate
showString.vertical = vertical
var relation = Relation()
relation.showString = showString
var plan = Plan()
plan.opType = .root(relation)
return plan
}

func getTreeString(_ sessionID: String, _ plan: Plan, _ level: Int32) async -> AnalyzePlanRequest
{
return analyze(
Expand Down
1 change: 1 addition & 0 deletions Sources/SparkConnect/TypeAliases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ typealias Sample = Spark_Connect_Sample
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
typealias SetOperation = Spark_Connect_SetOperation
typealias SetOpType = SetOperation.SetOpType
typealias ShowString = Spark_Connect_ShowString
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StructType = Spark_Connect_DataType.Struct
Expand Down
86 changes: 86 additions & 0 deletions Tests/SparkConnectTests/DataFrameInternalTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Testing

@testable import SparkConnect

/// A test suite for `DataFrame` internal APIs
struct DataFrameInternalTests {

#if !os(Linux)
@Test
func showString() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.range(10).showString(2, 0, false).collect()
#expect(rows.count == 1)
#expect(rows[0].length == 1)
#expect(
try rows[0].get(0) as! String == """
+---+
|id |
+---+
|0 |
|1 |
+---+
only showing top 2 rows
""")
await spark.stop()
}

@Test
func showStringTruncate() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')")
.showString(2, 2, false).collect()
#expect(rows.count == 1)
#expect(rows[0].length == 1)
print(try rows[0].get(0) as! String)
#expect(
try rows[0].get(0) as! String == """
+----+----+
|col1|col2|
+----+----+
| ab| de|
| gh| jk|
+----+----+

""")
await spark.stop()
}

@Test
func showStringVertical() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.range(10).showString(2, 0, true).collect()
#expect(rows.count == 1)
#expect(rows[0].length == 1)
print(try rows[0].get(0) as! String)
#expect(
try rows[0].get(0) as! String == """
-RECORD 0--
id | 0
-RECORD 1--
id | 1
only showing top 2 rows
""")
await spark.stop()
}
#endif
}
12 changes: 12 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ struct DataFrameTests {
try await spark.sql("SELECT * FROM VALUES (true, false)").show()
try await spark.sql("SELECT * FROM VALUES (1, 2)").show()
try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show()

// Check all signatures
try await spark.range(1000).show()
try await spark.range(1000).show(1)
try await spark.range(1000).show(true)
try await spark.range(1000).show(false)
try await spark.range(1000).show(1, true)
try await spark.range(1000).show(1, false)
try await spark.range(1000).show(1, 20)
try await spark.range(1000).show(1, 20, true)
try await spark.range(1000).show(1, 20, false)

await spark.stop()
}

Expand Down
Loading