diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift index 24ae1f6..c631671 100644 --- a/Tests/SparkConnectTests/CatalogTests.swift +++ b/Tests/SparkConnectTests/CatalogTests.swift @@ -25,297 +25,295 @@ import Testing /// A test suite for `Catalog` @Suite(.serialized) struct CatalogTests { - #if !os(Linux) - @Test - func currentCatalog() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.currentCatalog() == "spark_catalog") - await spark.stop() - } + @Test + func currentCatalog() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.currentCatalog() == "spark_catalog") + await spark.stop() + } - @Test - func setCurrentCatalog() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.catalog.setCurrentCatalog("spark_catalog") - if await spark.version >= "4.0.0" { - try await #require(throws: SparkConnectError.CatalogNotFound) { - try await spark.catalog.setCurrentCatalog("not_exist_catalog") - } - } else { - try await #require(throws: Error.self) { - try await spark.catalog.setCurrentCatalog("not_exist_catalog") - } + @Test + func setCurrentCatalog() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.catalog.setCurrentCatalog("spark_catalog") + if await spark.version >= "4.0.0" { + try await #require(throws: SparkConnectError.CatalogNotFound) { + try await spark.catalog.setCurrentCatalog("not_exist_catalog") + } + } else { + try await #require(throws: Error.self) { + try await spark.catalog.setCurrentCatalog("not_exist_catalog") } - await spark.stop() } + await spark.stop() + } - @Test - func listCatalogs() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")]) - #expect( - try await spark.catalog.listCatalogs(pattern: "*") == [ - CatalogMetadata(name: "spark_catalog") - ]) - #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0) - await spark.stop() - } + @Test + func listCatalogs() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")]) + #expect( + try await spark.catalog.listCatalogs(pattern: "*") == [ + CatalogMetadata(name: "spark_catalog") + ]) + #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0) + await spark.stop() + } - @Test - func currentDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.currentDatabase() == "default") - await spark.stop() - } + @Test + func currentDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.currentDatabase() == "default") + await spark.stop() + } - @Test - func setCurrentDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.catalog.setCurrentDatabase("default") - try await #require(throws: SparkConnectError.SchemaNotFound) { - try await spark.catalog.setCurrentDatabase("not_exist_database") - } - await spark.stop() + @Test + func setCurrentDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.catalog.setCurrentDatabase("default") + try await #require(throws: SparkConnectError.SchemaNotFound) { + try await 
spark.catalog.setCurrentDatabase("not_exist_database") } + await spark.stop() + } - @Test - func listDatabases() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let dbs = try await spark.catalog.listDatabases() - #expect(dbs.count == 1) - #expect(dbs[0].name == "default") - #expect(dbs[0].catalog == "spark_catalog") - #expect(dbs[0].description == "default database") - #expect(dbs[0].locationUri.hasSuffix("spark-warehouse")) - #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs) - #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0) - await spark.stop() - } + @Test + func listDatabases() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let dbs = try await spark.catalog.listDatabases() + #expect(dbs.count == 1) + #expect(dbs[0].name == "default") + #expect(dbs[0].catalog == "spark_catalog") + #expect(dbs[0].description == "default database") + #expect(dbs[0].locationUri.hasSuffix("spark-warehouse")) + #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs) + #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0) + await spark.stop() + } - @Test - func getDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let db = try await spark.catalog.getDatabase("default") - #expect(db.name == "default") - #expect(db.catalog == "spark_catalog") - #expect(db.description == "default database") - #expect(db.locationUri.hasSuffix("spark-warehouse")) - try await #require(throws: SparkConnectError.SchemaNotFound) { - try await spark.catalog.getDatabase("not_exist_database") - } - await spark.stop() + @Test + func getDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let db = try await spark.catalog.getDatabase("default") + #expect(db.name == "default") + #expect(db.catalog == "spark_catalog") + #expect(db.description == "default database") + #expect(db.locationUri.hasSuffix("spark-warehouse")) + try await #require(throws: SparkConnectError.SchemaNotFound) { + try await spark.catalog.getDatabase("not_exist_database") } + await spark.stop() + } - @Test - func databaseExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.databaseExists("default")) - - let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - #expect(try await spark.catalog.databaseExists(dbName) == false) - try await SQLHelper.withDatabase(spark, dbName)({ - try await spark.sql("CREATE DATABASE \(dbName)").count() - #expect(try await spark.catalog.databaseExists(dbName)) - }) - #expect(try await spark.catalog.databaseExists(dbName) == false) - await spark.stop() - } + @Test + func databaseExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.databaseExists("default")) - @Test - func createTable() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTable(spark, tableName)({ - try await spark.range(1).write.orc("/tmp/\(tableName)") - #expect( - try await spark.catalog.createTable(tableName, "/tmp/\(tableName)", source: "orc").count() - == 1) - #expect(try await spark.catalog.tableExists(tableName)) - }) - await spark.stop() - } + let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + #expect(try await spark.catalog.databaseExists(dbName) == 
false) + try await SQLHelper.withDatabase(spark, dbName)({ + try await spark.sql("CREATE DATABASE \(dbName)").count() + #expect(try await spark.catalog.databaseExists(dbName)) + }) + #expect(try await spark.catalog.databaseExists(dbName) == false) + await spark.stop() + } - @Test - func tableExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTable(spark, tableName)({ - try await spark.range(1).write.parquet("/tmp/\(tableName)") - #expect(try await spark.catalog.tableExists(tableName) == false) - #expect(try await spark.catalog.createTable(tableName, "/tmp/\(tableName)").count() == 1) - #expect(try await spark.catalog.tableExists(tableName)) - #expect(try await spark.catalog.tableExists("default", tableName)) - #expect(try await spark.catalog.tableExists("default2", tableName) == false) - }) + @Test + func createTable() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(1).write.orc("/tmp/\(tableName)") + #expect( + try await spark.catalog.createTable(tableName, "/tmp/\(tableName)", source: "orc").count() + == 1) + #expect(try await spark.catalog.tableExists(tableName)) + }) + await spark.stop() + } + + @Test + func tableExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(1).write.parquet("/tmp/\(tableName)") #expect(try await spark.catalog.tableExists(tableName) == false) + #expect(try await spark.catalog.createTable(tableName, "/tmp/\(tableName)").count() == 1) + #expect(try await spark.catalog.tableExists(tableName)) + #expect(try await spark.catalog.tableExists("default", tableName)) + #expect(try await spark.catalog.tableExists("default2", tableName) == false) + }) + #expect(try await spark.catalog.tableExists(tableName) == false) - try await #require(throws: SparkConnectError.ParseSyntaxError) { - try await spark.catalog.tableExists("invalid table name") - } - await spark.stop() + try await #require(throws: SparkConnectError.ParseSyntaxError) { + try await spark.catalog.tableExists("invalid table name") } + await spark.stop() + } - @Test - func listColumns() async throws { - let spark = try await SparkSession.builder.getOrCreate() - - // Table - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - let path = "/tmp/\(tableName)" - try await SQLHelper.withTable(spark, tableName)({ - try await spark.range(2).write.orc(path) - let expected = - if await spark.version.starts(with: "4.") { - [Row("id", nil, "bigint", true, false, false, false)] - } else { - [Row("id", nil, "bigint", true, false, false)] - } - #expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2) - #expect(try await spark.catalog.listColumns(tableName).collect() == expected) - #expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected) - }) - - // View - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - try await spark.range(1).createTempView(viewName) - let expected = - if await spark.version.starts(with: "4.") { - 
[Row("id", nil, "bigint", false, false, false, false)] - } else { - [Row("id", nil, "bigint", false, false, false)] - } - #expect(try await spark.catalog.listColumns(viewName).collect() == expected) - }) - - await spark.stop() - } + @Test + func listColumns() async throws { + let spark = try await SparkSession.builder.getOrCreate() - @Test - func functionExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.functionExists("base64")) - #expect(try await spark.catalog.functionExists("non_exist_function") == false) + // Table + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + let path = "/tmp/\(tableName)" + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(2).write.orc(path) + let expected = + if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", true, false, false, false)] + } else { + [Row("id", nil, "bigint", true, false, false)] + } + #expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2) + #expect(try await spark.catalog.listColumns(tableName).collect() == expected) + #expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected) + }) - try await #require(throws: SparkConnectError.ParseSyntaxError) { - try await spark.catalog.functionExists("invalid function name") - } - await spark.stop() - } + // View + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + try await spark.range(1).createTempView(viewName) + let expected = + if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", false, false, false, false)] + } else { + [Row("id", nil, "bigint", false, false, false)] + } + #expect(try await spark.catalog.listColumns(viewName).collect() == expected) + }) - @Test - func createTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName)) + await spark.stop() + } - try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { - try await spark.range(1).createTempView(viewName) - } - }) + @Test + func functionExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.functionExists("base64")) + #expect(try await spark.catalog.functionExists("non_exist_function") == false) - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createTempView("invalid view name") + try await #require(throws: SparkConnectError.ParseSyntaxError) { + try await spark.catalog.functionExists("invalid function name") + } + await spark.stop() + } + + @Test + func createTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName)) + + try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { + try await 
spark.range(1).createTempView(viewName) } + }) - await spark.stop() + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createTempView("invalid view name") } - @Test - func createOrReplaceTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createOrReplaceTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName)) - try await spark.range(1).createOrReplaceTempView(viewName) - }) - - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createOrReplaceTempView("invalid view name") - } + await spark.stop() + } + + @Test + func createOrReplaceTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createOrReplaceTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName)) + try await spark.range(1).createOrReplaceTempView(viewName) + }) - await spark.stop() + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createOrReplaceTempView("invalid view name") } - @Test - func createGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withGlobalTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - try await spark.range(1).createGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) + await spark.stop() + } - try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { - try await spark.range(1).createGlobalTempView(viewName) - } - }) + @Test + func createGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withGlobalTempView(spark, viewName)({ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + try await spark.range(1).createGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createGlobalTempView("invalid view name") + try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { + try await spark.range(1).createGlobalTempView(viewName) } + }) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - await spark.stop() + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createGlobalTempView("invalid view name") } - @Test - func createOrReplaceGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withGlobalTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - try await 
spark.range(1).createOrReplaceGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await spark.range(1).createOrReplaceGlobalTempView(viewName) - }) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + await spark.stop() + } - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createOrReplaceGlobalTempView("invalid view name") - } + @Test + func createOrReplaceGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withGlobalTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + try await spark.range(1).createOrReplaceGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) + try await spark.range(1).createOrReplaceGlobalTempView(viewName) + }) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - await spark.stop() + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createOrReplaceGlobalTempView("invalid view name") } - @Test - func dropTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createTempView(viewName) - try await spark.catalog.dropTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName) == false) - }) + await spark.stop() + } - #expect(try await spark.catalog.dropTempView("non_exist_view") == false) - #expect(try await spark.catalog.dropTempView("invalid view name") == false) - await spark.stop() - } + @Test + func dropTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createTempView(viewName) + try await spark.catalog.dropTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName) == false) + }) - @Test - func dropGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await spark.catalog.dropGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - }) - - #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") == false) - #expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false) - await spark.stop() - } - #endif + #expect(try await spark.catalog.dropTempView("non_exist_view") == false) + #expect(try await spark.catalog.dropTempView("invalid view name") == false) + await spark.stop() + } + + @Test + func dropGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + 
UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) + try await spark.catalog.dropGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + }) + + #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") == false) + #expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false) + await spark.stop() + } @Test func cacheTable() async throws { diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift index 6c843c3..1b79419 100644 --- a/Tests/SparkConnectTests/DataFrameInternalTests.swift +++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift @@ -25,63 +25,61 @@ import Testing @Suite(.serialized) struct DataFrameInternalTests { - #if !os(Linux) - @Test - func showString() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(10).showString(2, 0, false).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - #expect( - try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ - +---+ - |id | - +---+ - |0 | - |1 | - +---+ - only showing top 2 rows - """) - await spark.stop() - } + @Test + func showString() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + #expect( + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ + +---+ + |id | + +---+ + |0 | + |1 | + +---+ + only showing top 2 rows + """) + await spark.stop() + } - @Test - func showStringTruncate() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')") - .showString(2, 2, false).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - print(try rows[0].get(0) as! String) - #expect( - try rows[0].get(0) as! String == """ - +----+----+ - |col1|col2| - +----+----+ - | ab| de| - | gh| jk| - +----+----+ + @Test + func showStringTruncate() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')") + .showString(2, 2, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! String) + #expect( + try rows[0].get(0) as! String == """ + +----+----+ + |col1|col2| + +----+----+ + | ab| de| + | gh| jk| + +----+----+ - """) - await spark.stop() - } + """) + await spark.stop() + } - @Test - func showStringVertical() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(10).showString(2, 0, true).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - print(try rows[0].get(0) as! String) - #expect( - try (rows[0].get(0) as! 
String).trimmingCharacters(in: .whitespacesAndNewlines) == """ - -RECORD 0-- - id | 0 - -RECORD 1-- - id | 1 - only showing top 2 rows - """) - await spark.stop() - } - #endif + @Test + func showStringVertical() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, true).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! String) + #expect( + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ + -RECORD 0-- + id | 0 + -RECORD 1-- + id | 1 + only showing top 2 rows + """) + await spark.stop() + } } diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 8db57d3..256ce09 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -332,7 +332,7 @@ struct DataFrameTests { @Test func isLocal() async throws { let spark = try await SparkSession.builder.getOrCreate() - if !(await spark.version.starts(with: "4.1")) { // TODO(SPARK-52746) + if !(await spark.version.starts(with: "4.1")) { // TODO(SPARK-52746) #expect(try await spark.sql("SHOW DATABASES").isLocal()) #expect(try await spark.sql("SHOW TABLES").isLocal()) } @@ -347,23 +347,21 @@ struct DataFrameTests { await spark.stop() } - #if !os(Linux) - @Test - func sort() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = Array((1...10).map { Row($0) }) - #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected) - await spark.stop() - } + @Test + func sort() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = Array((1...10).map { Row($0) }) + #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected) + await spark.stop() + } - @Test - func orderBy() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = Array((1...10).map { Row($0) }) - #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected) - await spark.stop() - } - #endif + @Test + func orderBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = Array((1...10).map { Row($0) }) + #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected) + await spark.stop() + } @Test func table() async throws { @@ -379,167 +377,207 @@ struct DataFrameTests { await spark.stop() } - #if !os(Linux) - @Test - func collect() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).collect().isEmpty) - #expect( - try await spark.sql( - "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" - ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false, "def")]) - await spark.stop() - } + @Test + func collect() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).collect().isEmpty) + #expect( + try await spark.sql( + "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" + ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false, "def")]) + await spark.stop() + } - @Test - func collectMultiple() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1) - #expect(try await df.collect().count == 1) - #expect(try await df.collect().count == 1) - await spark.stop() - } + @Test + func collectMultiple() async 
throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1) + #expect(try await df.collect().count == 1) + #expect(try await df.collect().count == 1) + await spark.stop() + } - @Test - func first() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(2).sort("id").first() == Row(0)) - #expect(try await spark.range(2).sort("id").head() == Row(0)) - await spark.stop() - } + @Test + func first() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(2).sort("id").first() == Row(0)) + #expect(try await spark.range(2).sort("id").head() == Row(0)) + await spark.stop() + } - @Test - func head() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).head(1).isEmpty) - #expect(try await spark.range(2).sort("id").head() == Row(0)) - #expect(try await spark.range(2).sort("id").head(1) == [Row(0)]) - #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)]) - #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func head() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).head(1).isEmpty) + #expect(try await spark.range(2).sort("id").head() == Row(0)) + #expect(try await spark.range(2).sort("id").head(1) == [Row(0)]) + #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)]) + await spark.stop() + } - @Test - func take() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).take(1).isEmpty) - #expect(try await spark.range(2).sort("id").take(1) == [Row(0)]) - #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)]) - #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func take() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).take(1).isEmpty) + #expect(try await spark.range(2).sort("id").take(1) == [Row(0)]) + #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)]) + await spark.stop() + } - @Test - func tail() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).tail(1).isEmpty) - #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)]) - #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)]) - #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func tail() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).tail(1).isEmpty) + #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)]) + #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)]) + await spark.stop() + } - @Test - func show() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql("SHOW TABLES").show() - try await spark.sql("SELECT * FROM VALUES (true, false)").show() - try await spark.sql("SELECT * FROM VALUES (1, 2)").show() - try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show() - - // Check 
all signatures - try await spark.range(1000).show() - try await spark.range(1000).show(1) - try await spark.range(1000).show(true) - try await spark.range(1000).show(false) - try await spark.range(1000).show(1, true) - try await spark.range(1000).show(1, false) - try await spark.range(1000).show(1, 20) - try await spark.range(1000).show(1, 20, true) - try await spark.range(1000).show(1, 20, false) - - await spark.stop() - } + @Test + func show() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql("SHOW TABLES").show() + try await spark.sql("SELECT * FROM VALUES (true, false)").show() + try await spark.sql("SELECT * FROM VALUES (1, 2)").show() + try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show() + + // Check all signatures + try await spark.range(1000).show() + try await spark.range(1000).show(1) + try await spark.range(1000).show(true) + try await spark.range(1000).show(false) + try await spark.range(1000).show(1, true) + try await spark.range(1000).show(1, false) + try await spark.range(1000).show(1, 20) + try await spark.range(1000).show(1, 20, true) + try await spark.range(1000).show(1, 20, false) - @Test - func showNull() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql( - "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" - ).show() - await spark.stop() - } + await spark.stop() + } - @Test - func showCommand() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql("DROP TABLE IF EXISTS t").show() - await spark.stop() - } + @Test + func showNull() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql( + "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" + ).show() + await spark.stop() + } - @Test - func cache() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(10).cache().count() == 10) - await spark.stop() - } + @Test + func showCommand() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql("DROP TABLE IF EXISTS t").show() + await spark.stop() + } - @Test - func checkpoint() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version >= "4.0.0" { - // By default, reliable checkpoint location is required. - try await #require(throws: Error.self) { - try await spark.range(10).checkpoint() - } - // Checkpointing with unreliable checkpoint - let df = try await spark.range(10).checkpoint(true, false) - #expect(try await df.count() == 10) - } - await spark.stop() - } + @Test + func cache() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).cache().count() == 10) + await spark.stop() + } - @Test - func localCheckpoint() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version >= "4.0.0" { - #expect(try await spark.range(10).localCheckpoint().count() == 10) + @Test + func checkpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + // By default, reliable checkpoint location is required. 
+ try await #require(throws: Error.self) { + try await spark.range(10).checkpoint() } - await spark.stop() + // Checkpointing with unreliable checkpoint + let df = try await spark.range(10).checkpoint(true, false) + #expect(try await df.count() == 10) } + await spark.stop() + } - @Test - func persist() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(20).persist().count() == 20) - #expect( - try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21) - await spark.stop() + @Test + func localCheckpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + #expect(try await spark.range(10).localCheckpoint().count() == 10) } + await spark.stop() + } - @Test - func persistInvalidStorageLevel() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await #require(throws: Error.self) { - var invalidLevel = StorageLevel.DISK_ONLY - invalidLevel.replication = 0 - try await spark.range(9999).persist(storageLevel: invalidLevel).count() - } - await spark.stop() - } + @Test + func persist() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(20).persist().count() == 20) + #expect( + try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21) + await spark.stop() + } - @Test - func unpersist() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(30) - #expect(try await df.persist().count() == 30) - #expect(try await df.unpersist().count() == 30) - await spark.stop() + @Test + func persistInvalidStorageLevel() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await #require(throws: Error.self) { + var invalidLevel = StorageLevel.DISK_ONLY + invalidLevel.replication = 0 + try await spark.range(9999).persist(storageLevel: invalidLevel).count() } + await spark.stop() + } + + @Test + func unpersist() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(30) + #expect(try await df.persist().count() == 30) + #expect(try await df.unpersist().count() == 30) + await spark.stop() + } + + @Test + func join() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") + let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") + let expectedCross = [ + Row("a", 1, "c", 2), + Row("a", 1, "d", 3), + Row("b", 2, "c", 2), + Row("b", 2, "d", 3), + ] + #expect(try await df1.join(df2).collect() == expectedCross) + #expect(try await df1.crossJoin(df2).collect() == expectedCross) + + #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")]) + #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")]) + + #expect( + try await df1.join(df2, "b", "left").collect() == [Row(1, "a", nil), Row(2, "b", "c")]) + #expect( + try await df1.join(df2, "b", "right").collect() == [Row(2, "b", "c"), Row(3, nil, "d")]) + #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")]) + #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")]) + + let expectedOuter = [ + Row(1, "a", nil), + Row(2, "b", "c"), + Row(3, nil, "d"), + ] + #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter) + #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter) + 
#expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter) + + let expected = [Row("b", 2, "c", 2)] + #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected) + #expect( + try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) + await spark.stop() + } - @Test - func join() async throws { - let spark = try await SparkSession.builder.getOrCreate() + @Test + func lateralJoin() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version.starts(with: "4.") { let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") let expectedCross = [ @@ -548,407 +586,365 @@ struct DataFrameTests { Row("b", 2, "c", 2), Row("b", 2, "d", 3), ] - #expect(try await df1.join(df2).collect() == expectedCross) - #expect(try await df1.crossJoin(df2).collect() == expectedCross) - - #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")]) - #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")]) - - #expect( - try await df1.join(df2, "b", "left").collect() == [Row(1, "a", nil), Row(2, "b", "c")]) - #expect( - try await df1.join(df2, "b", "right").collect() == [Row(2, "b", "c"), Row(3, nil, "d")]) - #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")]) - #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")]) - - let expectedOuter = [ - Row(1, "a", nil), - Row(2, "b", "c"), - Row(3, nil, "d"), - ] - #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter) - #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter) - #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter) + #expect(try await df1.lateralJoin(df2).collect() == expectedCross) + #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) let expected = [Row("b", 2, "c", 2)] - #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected) + #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) #expect( - try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) - await spark.stop() + try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() + == expected) } + await spark.stop() + } - @Test - func lateralJoin() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version.starts(with: "4.") { - let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") - let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") - let expectedCross = [ - Row("a", 1, "c", 2), - Row("a", 1, "d", 3), - Row("b", 2, "c", 2), - Row("b", 2, "d", 3), - ] - #expect(try await df1.lateralJoin(df2).collect() == expectedCross) - #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) - - let expected = [Row("b", 2, "c", 2)] - #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) - #expect( - try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() - == expected) - } - await spark.stop() - } + @Test + func except() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.except(spark.range(1, 5)).collect() == []) + #expect(try await df.except(spark.range(2, 5)).collect() == 
[Row(1)]) + #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), Row(2)]) + #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0) + await spark.stop() + } - @Test - func except() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.except(spark.range(1, 5)).collect() == []) - #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)]) - #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), Row(2)]) - #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0) - await spark.stop() - } + @Test + func exceptAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.exceptAll(spark.range(1, 5)).collect() == []) + #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)]) + #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), Row(2)]) + #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1) + await spark.stop() + } - @Test - func exceptAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.exceptAll(spark.range(1, 5)).collect() == []) - #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)]) - #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), Row(2)]) - #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1) - await spark.stop() - } + @Test + func intersect() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), Row(2)]) + #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)]) + #expect(try await df.intersect(spark.range(3, 5)).collect() == []) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.intersect(df2).count() == 1) + await spark.stop() + } - @Test - func intersect() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), Row(2)]) - #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)]) - #expect(try await df.intersect(spark.range(3, 5)).collect() == []) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.intersect(df2).count() == 1) - await spark.stop() - } + @Test + func intersectAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), Row(2)]) + #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)]) + #expect(try await df.intersectAll(spark.range(3, 5)).collect() == []) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.intersectAll(df2).count() == 2) + await spark.stop() + } - @Test - func intersectAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), Row(2)]) - #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)]) - #expect(try await df.intersectAll(spark.range(3, 5)).collect() == []) - let df2 = try await 
spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.intersectAll(df2).count() == 2) - await spark.stop() - } + @Test + func union() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 2) + #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) + #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), Row(2)]) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.union(df2).count() == 4) + await spark.stop() + } - @Test - func union() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 2) - #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) - #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), Row(2)]) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.union(df2).count() == 4) - await spark.stop() - } + @Test + func unionAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 2) + #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) + #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), Row(2)]) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.unionAll(df2).count() == 4) + await spark.stop() + } - @Test - func unionAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 2) - #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) - #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), Row(2)]) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.unionAll(df2).count() == 4) - await spark.stop() - } + @Test + func unionByName() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df1 = try await spark.sql("SELECT 1 a, 2 b") + let df2 = try await spark.sql("SELECT 4 b, 3 a") + #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)]) + #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)]) + let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df3.unionByName(df3).count() == 4) + await spark.stop() + } - @Test - func unionByName() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df1 = try await spark.sql("SELECT 1 a, 2 b") - let df2 = try await spark.sql("SELECT 4 b, 3 a") - #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)]) - #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)]) - let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df3.unionByName(df3).count() == 4) - await spark.stop() - } + @Test + func repartition() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 3, 5] as [Int32] { + try await df.repartition(n).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func repartition() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + 
UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 3, 5] as [Int32] { - try await df.repartition(n).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func repartitionByExpression() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 3, 5] as [Int32] { + try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func repartitionByExpression() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 3, 5] as [Int32] { - try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func coalesce() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 2, 3] as [Int32] { + try await df.coalesce(n).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func coalesce() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 2, 3] as [Int32] { - try await df.coalesce(n).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func distinct() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.distinct().count() == 3) + await spark.stop() + } - @Test - func distinct() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM 
VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.distinct().count() == 3) - await spark.stop() - } + @Test + func dropDuplicates() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.dropDuplicates().count() == 3) + #expect(try await df.dropDuplicates("a").count() == 3) + await spark.stop() + } - @Test - func dropDuplicates() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.dropDuplicates().count() == 3) - #expect(try await df.dropDuplicates("a").count() == 3) - await spark.stop() - } + @Test + func dropDuplicatesWithinWatermark() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.dropDuplicatesWithinWatermark().count() == 3) + #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3) + await spark.stop() + } - @Test - func dropDuplicatesWithinWatermark() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.dropDuplicatesWithinWatermark().count() == 3) - #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3) - await spark.stop() - } + @Test + func withWatermark() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = + try await spark + .sql( + """ + SELECT * FROM VALUES + (1, now()), + (1, now() - INTERVAL 1 HOUR), + (1, now() - INTERVAL 2 HOUR) + T(data, eventTime) + """ + ) + .withWatermark("eventTime", "1 minute") // This tests only API for now + #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1) + await spark.stop() + } - @Test - func withWatermark() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = - try await spark - .sql( - """ - SELECT * FROM VALUES - (1, now()), - (1, now() - INTERVAL 1 HOUR), - (1, now() - INTERVAL 2 HOUR) - T(data, eventTime) - """ - ) - .withWatermark("eventTime", "1 minute") // This tests only API for now - #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1) - await spark.stop() - } + @Test + func describe() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(10) + let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"), Row("0"), Row("9")] + #expect(try await df.describe().select("id").collect() == expected) + #expect(try await df.describe("id").select("id").collect() == expected) + await spark.stop() + } - @Test - func describe() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(10) - let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"), Row("0"), Row("9")] - #expect(try await df.describe().select("id").collect() == expected) - #expect(try await df.describe("id").select("id").collect() == expected) - await spark.stop() - } + @Test + func summary() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = [ + Row("10"), Row("4.5"), Row("3.0276503540974917"), + Row("0"), Row("2"), Row("4"), Row("7"), Row("9"), + ] + #expect(try await spark.range(10).summary().select("id").collect() == expected) + #expect( + try await 
spark.range(10).summary("min", "max").select("id").collect() == [ + Row("0"), Row("9"), + ]) + await spark.stop() + } - @Test - func summary() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = [ - Row("10"), Row("4.5"), Row("3.0276503540974917"), - Row("0"), Row("2"), Row("4"), Row("7"), Row("9"), - ] - #expect(try await spark.range(10).summary().select("id").collect() == expected) - #expect( - try await spark.range(10).summary("min", "max").select("id").collect() == [ - Row("0"), Row("9"), - ]) - await spark.stop() - } + @Test + func groupBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)") + .collect() + #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)]) + await spark.stop() + } - @Test - func groupBy() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)") - .collect() - #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)]) - await spark.stop() - } + @Test + func rollup() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model") + .agg("sum(quantity) sum").orderBy("city", "car_model").collect() + #expect( + rows == [ + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Dublin", nil, 33), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("Fremont", nil, 32), + Row("San Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + Row("San Jose", nil, 13), + Row(nil, nil, 78), + ]) + await spark.stop() + } - @Test - func rollup() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model") - .agg("sum(quantity) sum").orderBy("city", "car_model").collect() - #expect( - rows == [ - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Dublin", nil, 33), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("Fremont", nil, 32), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - Row("San Jose", nil, 13), - Row(nil, nil, 78), - ]) - await spark.stop() - } + @Test + func cube() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model") + .agg("sum(quantity) sum").orderBy("city", "car_model").collect() + #expect( + rows == [ + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Dublin", nil, 33), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("Fremont", nil, 32), + Row("San Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + Row("San Jose", nil, 13), + Row(nil, "Honda Accord", 33), + Row(nil, "Honda CRV", 10), + Row(nil, "Honda Civic", 35), + Row(nil, nil, 78), + ]) + await spark.stop() + } - @Test - func cube() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model") - .agg("sum(quantity) sum").orderBy("city", "car_model").collect() - 
#expect( - rows == [ - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Dublin", nil, 33), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("Fremont", nil, 32), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - Row("San Jose", nil, 13), - Row(nil, "Honda Accord", 33), - Row(nil, "Honda CRV", 10), - Row(nil, "Honda Civic", 35), - Row(nil, nil, 78), - ]) - await spark.stop() - } + @Test + func toJSON() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(2).toJSON() + #expect(try await df.columns == ["to_json(struct(id))"]) + #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")]) - @Test - func toJSON() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(2).toJSON() - #expect(try await df.columns == ["to_json(struct(id))"]) - #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")]) + let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")] + #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected) + await spark.stop() + } - let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")] - #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected) - await spark.stop() - } + @Test + func unpivot() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql( + """ + SELECT * FROM + VALUES (1, 11, 12L), + (2, 21, 22L) + T(id, int, long) + """) + let expected = [ + Row(1, "int", 11), + Row(1, "long", 12), + Row(2, "int", 21), + Row(2, "long", 22), + ] + #expect( + try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expected) + #expect( + try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expected) + await spark.stop() + } + + @Test + func transpose() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version.starts(with: "4.") { + #expect(try await spark.range(1).transpose().columns == ["key", "0"]) + #expect(try await spark.range(1).transpose().count() == 0) - @Test - func unpivot() async throws { - let spark = try await SparkSession.builder.getOrCreate() let df = try await spark.sql( """ SELECT * FROM - VALUES (1, 11, 12L), - (2, 21, 22L) - T(id, int, long) + VALUES ('A', 1, 2), + ('B', 3, 4) + T(id, val1, val2) """) let expected = [ - Row(1, "int", 11), - Row(1, "long", 12), - Row(2, "int", 21), - Row(2, "long", 22), + Row("val1", 1, 3), + Row("val2", 2, 4), ] - #expect( - try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expected) - #expect( - try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expected) - await spark.stop() - } - - @Test - func transpose() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version.starts(with: "4.") { - #expect(try await spark.range(1).transpose().columns == ["key", "0"]) - #expect(try await spark.range(1).transpose().count() == 0) - - let df = try await spark.sql( - """ - SELECT * FROM - VALUES ('A', 1, 2), - ('B', 3, 4) - T(id, val1, val2) - """) - let expected = [ - Row("val1", 1, 3), - Row("val2", 2, 4), - ] - #expect(try await df.transpose().collect() == expected) - #expect(try await df.transpose("id").collect() == expected) - } - await spark.stop() + #expect(try await df.transpose().collect() == 
expected) + #expect(try await df.transpose("id").collect() == expected) } + await spark.stop() + } - @Test - func decimal() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql( - """ - SELECT * FROM VALUES - (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)), - (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL)) - """) - #expect( - try await df.dtypes.map { $0.1 } == [ - "decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)", - ]) - let expected = [ - Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)), - Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil), - ] - #expect(try await df.collect() == expected) - await spark.stop() - } + @Test + func decimal() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql( + """ + SELECT * FROM VALUES + (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)), + (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL)) + """) + #expect( + try await df.dtypes.map { $0.1 } == [ + "decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)", + ]) + let expected = [ + Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)), + Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil), + ] + #expect(try await df.collect() == expected) + await spark.stop() + } - @Test - func timestamp() async throws { - let spark = try await SparkSession.builder.getOrCreate() - // TODO(SPARK-52747) - let df = try await spark.sql( - "SELECT TIMESTAMP '2025-05-01 16:23:40Z', TIMESTAMP '2025-05-01 16:23:40.123456Z'") - let expected = [ - Row( - Date(timeIntervalSince1970: 1746116620.0), Date(timeIntervalSince1970: 1746116620.123456)) - ] - #expect(try await df.collect() == expected) - await spark.stop() - } - #endif + @Test + func timestamp() async throws { + let spark = try await SparkSession.builder.getOrCreate() + // TODO(SPARK-52747) + let df = try await spark.sql( + "SELECT TIMESTAMP '2025-05-01 16:23:40Z', TIMESTAMP '2025-05-01 16:23:40.123456Z'") + let expected = [ + Row( + Date(timeIntervalSince1970: 1746116620.0), Date(timeIntervalSince1970: 1746116620.123456)) + ] + #expect(try await df.collect() == expected) + await spark.stop() + } @Test func storageLevel() async throws { diff --git a/Tests/SparkConnectTests/IcebergTests.swift b/Tests/SparkConnectTests/IcebergTests.swift index 94c6a8a..70095a5 100644 --- a/Tests/SparkConnectTests/IcebergTests.swift +++ b/Tests/SparkConnectTests/IcebergTests.swift @@ -99,14 +99,12 @@ struct IcebergTests { WHEN NOT MATCHED BY SOURCE THEN UPDATE SET data = 'invalid' """ ).count() - #if !os(Linux) - let expected = [ - Row(2, "updated"), - Row(3, "invalid"), - Row(4, "new"), - ] - #expect(try await spark.table(t2).collect() == expected) - #endif + let expected = [ + Row(2, "updated"), + Row(3, "invalid"), + Row(4, "new"), + ] + #expect(try await spark.table(t2).collect() == expected) }) await spark.stop() } diff --git a/Tests/SparkConnectTests/SQLTests.swift b/Tests/SparkConnectTests/SQLTests.swift index c6dc66e..45dc07b 100644 --- a/Tests/SparkConnectTests/SQLTests.swift +++ b/Tests/SparkConnectTests/SQLTests.swift @@ -100,46 +100,44 @@ struct SQLTests { ] let queriesForSpark41Only: [String] = [ - "time.sql", + "time.sql" ] - #if !os(Linux) - @Test - func runAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let MAX = Int32.max - for name in try! 
fm.contentsOfDirectory(atPath: path).sorted() { - guard name.hasSuffix(".sql") else { continue } - print(name) - if await !spark.version.starts(with: "4.") && queriesForSpark4Only.contains(name) { - print("Skip query \(name) due to the difference between Spark 3 and 4.") - continue - } - if await !spark.version.starts(with: "4.1") && queriesForSpark41Only.contains(name) { - print("Skip query \(name) due to the difference between Spark 4.0 and 4.1") - continue - } + @Test + func runAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let MAX = Int32.max + for name in try! fm.contentsOfDirectory(atPath: path).sorted() { + guard name.hasSuffix(".sql") else { continue } + print(name) + if await !spark.version.starts(with: "4.") && queriesForSpark4Only.contains(name) { + print("Skip query \(name) due to the difference between Spark 3 and 4.") + continue + } + if await !spark.version.starts(with: "4.1") && queriesForSpark41Only.contains(name) { + print("Skip query \(name) due to the difference between Spark 4.0 and 4.1") + continue + } - let sql = try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name)"), encoding: .utf8) - let result = - try await spark.sql(sql).showString(MAX, MAX, false).collect()[0].get(0) as! String - let answer = cleanUp(result.trimmingCharacters(in: .whitespacesAndNewlines)) - if regenerateGoldenFiles { - let path = - "\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer" - fm.createFile(atPath: path, contents: answer.data(using: .utf8)!, attributes: nil) - } else { - let expected = cleanUp( - try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name).answer"), encoding: .utf8) - ) - .trimmingCharacters(in: .whitespacesAndNewlines) - if answer != expected { - print("Try to compare normalized result.") - #expect(normalize(answer) == normalize(expected)) - } + let sql = try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name)"), encoding: .utf8) + let result = + try await spark.sql(sql).showString(MAX, MAX, false).collect()[0].get(0) as! 
String + let answer = cleanUp(result.trimmingCharacters(in: .whitespacesAndNewlines)) + if regenerateGoldenFiles { + let path = + "\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer" + fm.createFile(atPath: path, contents: answer.data(using: .utf8)!, attributes: nil) + } else { + let expected = cleanUp( + try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name).answer"), encoding: .utf8) + ) + .trimmingCharacters(in: .whitespacesAndNewlines) + if answer != expected { + print("Try to compare normalized result.") + #expect(normalize(answer) == normalize(expected)) } } - await spark.stop() } - #endif + await spark.stop() + } } diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index 5c3e634..787b3f6 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -130,76 +130,74 @@ struct SparkSessionTests { await spark.stop() } - #if !os(Linux) - @Test - func sql() async throws { - await SparkSession.builder.clear() - let spark = try await SparkSession.builder.getOrCreate() - let expected = [Row(true, 1, "a")] - if await spark.version.starts(with: "4.") { - #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected) - #expect( - try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() - == expected) - } - await spark.stop() + @Test + func sql() async throws { + await SparkSession.builder.clear() + let spark = try await SparkSession.builder.getOrCreate() + let expected = [Row(true, 1, "a")] + if await spark.version.starts(with: "4.") { + #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected) + #expect( + try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() + == expected) } + await spark.stop() + } - @Test - func addInvalidArtifact() async throws { - await SparkSession.builder.clear() - let spark = try await SparkSession.builder.getOrCreate() - await #expect(throws: SparkConnectError.InvalidArgument) { - try await spark.addArtifact("x.txt") - } - await spark.stop() + @Test + func addInvalidArtifact() async throws { + await SparkSession.builder.clear() + let spark = try await SparkSession.builder.getOrCreate() + await #expect(throws: SparkConnectError.InvalidArgument) { + try await spark.addArtifact("x.txt") } + await spark.stop() + } - @Test - func addArtifact() async throws { - let fm = FileManager() - let path = "my.jar" - let url = URL(fileURLWithPath: path) + @Test + func addArtifact() async throws { + let fm = FileManager() + let path = "my.jar" + let url = URL(fileURLWithPath: path) - await SparkSession.builder.clear() - let spark = try await SparkSession.builder.getOrCreate() - #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8))) - if await spark.version.starts(with: "4.") { - try await spark.addArtifact(path) - try await spark.addArtifact(url) - } - try fm.removeItem(atPath: path) - await spark.stop() + await SparkSession.builder.clear() + let spark = try await SparkSession.builder.getOrCreate() + #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8))) + if await spark.version.starts(with: "4.") { + try await spark.addArtifact(path) + try await spark.addArtifact(url) } + try fm.removeItem(atPath: path) + await spark.stop() + } - @Test - func addArtifacts() async throws { - let fm = FileManager() - let path = "my.jar" - let url = URL(fileURLWithPath: path) + @Test + func 
addArtifacts() async throws { + let fm = FileManager() + let path = "my.jar" + let url = URL(fileURLWithPath: path) - await SparkSession.builder.clear() - let spark = try await SparkSession.builder.getOrCreate() - #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8))) - if await spark.version.starts(with: "4.") { - try await spark.addArtifacts(url, url) - } - try fm.removeItem(atPath: path) - await spark.stop() + await SparkSession.builder.clear() + let spark = try await SparkSession.builder.getOrCreate() + #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8))) + if await spark.version.starts(with: "4.") { + try await spark.addArtifacts(url, url) } + try fm.removeItem(atPath: path) + await spark.stop() + } - @Test - func executeCommand() async throws { - await SparkSession.builder.clear() - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version.starts(with: "4.") { - await #expect(throws: SparkConnectError.DataSourceNotFound) { - try await spark.executeCommand("runner", "command", [:]).show() - } + @Test + func executeCommand() async throws { + await SparkSession.builder.clear() + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version.starts(with: "4.") { + await #expect(throws: SparkConnectError.DataSourceNotFound) { + try await spark.executeCommand("runner", "command", [:]).show() } - await spark.stop() } - #endif + await spark.stop() + } @Test func table() async throws { @@ -218,10 +216,8 @@ struct SparkSessionTests { await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.time(spark.range(1000).count) == 1000) - #if !os(Linux) - #expect(try await spark.time(spark.range(1).collect) == [Row(0)]) - try await spark.time(spark.range(10).show) - #endif + #expect(try await spark.time(spark.range(1).collect) == [Row(0)]) + try await spark.time(spark.range(10).show) await spark.stop() }
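
The hunks above all follow one test shape: acquire a session from `SparkSession.builder`, assert with `#expect` (or `#require(throws:)` for expected failures), gate server-version-specific behavior on `spark.version`, and stop the session before returning. A minimal sketch of that pattern, assuming the module name `SparkConnect` and the `Row` equality used throughout this suite; the suite name, test name, and SQL literals below are illustrative only and are not part of the patch:

```swift
import Testing

@testable import SparkConnect

@Suite(.serialized)
struct PatternSketchTests {
  @Test
  func versionGatedQuery() async throws {
    // Each test owns its session: build it here, stop it before returning.
    let spark = try await SparkSession.builder.getOrCreate()

    // Unconditional behavior is asserted directly against collected Rows.
    #expect(try await spark.sql("SELECT 1 a").collect() == [Row(1)])

    // Positional-parameter SQL is exercised only against Spark 4 servers in
    // this suite, so it is gated on `spark.version`, mirroring the
    // `starts(with: "4.")` checks in the hunks above.
    if await spark.version.starts(with: "4.") {
      #expect(try await spark.sql("SELECT ?", true).collect() == [Row(true)])
    }

    await spark.stop()
  }
}
```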