diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index 0d7546b..8a61d86 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -17,6 +17,7 @@ // under the License. // +import Dispatch import Foundation import GRPCCore import GRPCNIOTransportHTTP2 @@ -116,12 +117,26 @@ public actor SparkSession { return try await DataFrame(spark: self, sqlText: sqlText) } + /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a + /// `DataFrame` var read: DataFrameReader { get { return DataFrameReader(sparkSession: self) } } + /// Executes some code block and prints to stdout the time taken to execute the block. + /// - Parameter f: A function to execute. + /// - Returns: The result of the executed code. + public func time(_ f: () async throws -> T) async throws -> T { + let start = DispatchTime.now() + let ret = try await f() + let end = DispatchTime.now() + let elapsed = (end.uptimeNanoseconds - start.uptimeNanoseconds) / 1_000_000 + print("Time taken: \(elapsed) ms") + return ret + } + /// This is defined as the return type of `SparkSession.sparkContext` method. /// This is an empty `Struct` type because `sparkContext` method is designed to throw /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`. diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index 4a2a549..f302349 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -74,4 +74,15 @@ struct SparkSessionTests { #expect(try await spark.range(0, 100, 2).count() == 50) await spark.stop() } + + @Test + func time() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.time(spark.range(1000).count) == 1000) +#if !os(Linux) + #expect(try await spark.time(spark.range(1).collect) == [["0"]]) + try await spark.time(spark.range(10).show) +#endif + await spark.stop() + } }