21 changes: 21 additions & 0 deletions Sources/SparkConnect/DataFrameReader.swift
@@ -34,6 +34,8 @@ public actor DataFrameReader: Sendable {

var extraOptions: CaseInsensitiveDictionary = CaseInsensitiveDictionary([:])

var userSpecifiedSchemaDDL: String? = nil

let sparkSession: SparkSession

init(sparkSession: SparkSession) {
Expand Down Expand Up @@ -85,6 +87,22 @@ public actor DataFrameReader: Sendable {
return self
}

/// Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
/// automatically from data. By specifying the schema here, the underlying data source can skip
/// the schema inference step, and thus speed up data loading.
/// - Parameter schema: A DDL schema string.
/// - Returns: A `DataFrameReader`.
public func schema(_ schema: String) async throws -> DataFrameReader {
// Validate by parsing.
do {
_ = try await sparkSession.client.ddlParse(schema)
} catch {
throw SparkConnectError.InvalidTypeException
}
self.userSpecifiedSchemaDDL = schema
return self
}

/// Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
/// key-value stores).
/// - Returns: A `DataFrame`.
@@ -111,6 +129,9 @@
dataSource.format = self.source
dataSource.paths = self.paths
dataSource.options = self.extraOptions.toStringDictionary()
if let userSpecifiedSchemaDDL = self.userSpecifiedSchemaDDL {
dataSource.schema = userSpecifiedSchemaDDL
}

var read = Read()
read.dataSource = dataSource
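
Taken together, the reader now validates the DDL string server-side and forwards it into the DataSource message. A minimal usage sketch of the new API (hypothetical caller code, not part of this diff; assumes a reachable Spark Connect server, and the JSON path is illustrative):

// Hypothetical usage of the new schema(_:) API; the path is illustrative.
let spark = try await SparkSession.builder.getOrCreate()

// Supplying a DDL schema up front lets the JSON source skip schema inference.
let df = try await spark.read
    .schema("age SHORT, name STRING")
    .json("/path/to/people.json")

// dtypes reflects the user-specified types,
// e.g. [("age", "smallint"), ("name", "string")].
print(try await df.dtypes)

await spark.stop()
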
23 changes: 23 additions & 0 deletions Tests/SparkConnectTests/DataFrameReaderTests.swift
@@ -85,4 +85,27 @@ struct DataFrameReaderTests {
})
await spark.stop()
}

@Test
func schema() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let path = "../examples/src/main/resources/people.json"
#expect(try await spark.read.schema("age SHORT").json(path).dtypes.count == 1)
#expect(try await spark.read.schema("age SHORT").json(path).dtypes[0] == ("age", "smallint"))
#expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[0] == ("age", "smallint"))
#expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[1] == ("name", "string"))
Member:

Can we also add a test with a comment & null constraint?

Member Author:

I thought it was supported. But according to the Apache Spark 4.0.0 RC4, it seems there are limitations.
spark-shell

$ bin/spark-shell
WARNING: Using incubator modules: jdk.incubator.vector
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 4.0.0
      /_/

Using Scala version 2.13.16 (OpenJDK 64-Bit Server VM, Java 17.0.14)
Type in expressions to have them evaluated.
Type :help for more information.
25/04/15 12:32:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Spark context Web UI available at http://localhost:4040
Spark context available as 'sc' (master = local[*], app id = local-1744687967546).
Spark session available as 'spark'.

scala> spark.read.schema("name STRING NOT NULL").json("examples/src/main/resources/people.json").printSchema
warning: 1 deprecation (since 2.13.3); for details, enable `:setting -deprecation` or `:replay -deprecation`
root
 |-- name: string (nullable = true)

spark-connect-shell

$ bin/spark-connect-shell --remote sc://localhost:15002
25/04/15 12:28:48 INFO DefaultAllocationManagerOption: allocation manager type not specified, using netty as the default type
25/04/15 12:28:48 INFO CheckAllocator: Using DefaultAllocationManager at memory/netty/DefaultAllocationManagerFactory.class
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 4.0.0
      /_/

Type in expressions to have them evaluated.
Spark connect server version 4.0.0.
Spark session available as 'spark'.

scala> spark.read.schema("name STRING").json("../examples/src/main/resources/people.json").printSchema
root
 |-- name: string (nullable = true)

scala> spark.read.schema("name STRING NOT NULL").json("../examples/src/main/resources/people.json").printSchema
root
 |-- name: string (nullable = true)

scala> spark.read.schema("name STRING NOT NULL").json("../examples/src/main/resources/people.json").show()
+-------+
|   name|
+-------+
|Michael|
|   Andy|
| Justin|
+-------+

For that part, let me dig more, @yaooqinn.
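
A hedged sketch of the kind of test being requested (hypothetical, not part of this PR): the DDL parser accepts a COMMENT clause alongside NOT NULL, but the shell transcripts above suggest the nullability constraint is not propagated through Spark Connect, so the column still reads back as a plain string type:

  @Test
  func schemaWithCommentAndNotNull() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/people.json"
    // NOT NULL and COMMENT parse, but per the transcripts above the
    // nullability constraint is not enforced on the Connect read path.
    let df = try await spark.read
      .schema("name STRING NOT NULL COMMENT 'person name'")
      .json(path)
    #expect(try await df.dtypes[0] == ("name", "string"))
    await spark.stop()
  }
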

Member:

Thank you, @dongjoon-hyun.

await spark.stop()
}

@Test
func invalidSchema() async throws {
let spark = try await SparkSession.builder.getOrCreate()
await #expect(throws: SparkConnectError.InvalidTypeException) {
_ = try await spark.read.schema("invalid-name SHORT")
}
await #expect(throws: SparkConnectError.InvalidTypeException) {
_ = try await spark.read.schema("age UNKNOWN_TYPE")
}
await spark.stop()
}
}
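
For completeness, a hypothetical caller-side sketch (not part of this PR) of how the validation surfaces to users: an unparsable DDL string makes schema(_:) throw SparkConnectError.InvalidTypeException, so a caller could catch it and fall back to schema inference:

// Hypothetical error handling built on the validation behavior above;
// the path is illustrative.
let spark = try await SparkSession.builder.getOrCreate()
let path = "/path/to/people.json"
do {
    _ = try await spark.read.schema("age UNKNOWN_TYPE").json(path)
} catch SparkConnectError.InvalidTypeException {
    // The DDL failed to parse; fall back to letting the source infer the schema.
    _ = try await spark.read.json(path)
}
await spark.stop()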