Skip to content

Commit 7ab1e45

Browse files
committed
[SPARK-51799] Support user-specified schema in DataFrameReader
### What changes were proposed in this pull request?
This PR aims to support user-specified schema in `DataFrameReader`.

### Why are the changes needed?
For feature parity.

### Does this PR introduce _any_ user-facing change?
No. This is a new addition.

### How was this patch tested?
Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #58 from dongjoon-hyun/SPARK-51799.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 49a7739 commit 7ab1e45

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

Sources/SparkConnect/DataFrameReader.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ public actor DataFrameReader: Sendable {
3434

3535
var extraOptions: CaseInsensitiveDictionary = CaseInsensitiveDictionary([:])
3636

37+
var userSpecifiedSchemaDDL: String? = nil
38+
3739
let sparkSession: SparkSession
3840

3941
init(sparkSession: SparkSession) {
@@ -85,6 +87,22 @@ public actor DataFrameReader: Sendable {
8587
return self
8688
}
8789

90+
/// Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
/// automatically from data. By specifying the schema here, the underlying data source can skip
/// the schema inference step, and thus speed up data loading.
/// - Parameter schema: A DDL schema string.
/// - Returns: A `DataFrameReader`.
public func schema(_ schema: String) async throws -> DataFrameReader {
  // Validate eagerly: round-trip the DDL string through the server-side
  // parser so an unparsable schema fails here instead of at load time.
  // Any parse failure is surfaced as the connector's type error.
  if (try? await sparkSession.client.ddlParse(schema)) == nil {
    throw SparkConnectError.InvalidTypeException
  }
  userSpecifiedSchemaDDL = schema
  return self
}
105+
88106
/// Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
89107
/// key-value stores).
90108
/// - Returns: A `DataFrame`.
@@ -111,6 +129,9 @@ public actor DataFrameReader: Sendable {
111129
dataSource.format = self.source
112130
dataSource.paths = self.paths
113131
dataSource.options = self.extraOptions.toStringDictionary()
132+
if let userSpecifiedSchemaDDL = self.userSpecifiedSchemaDDL {
133+
dataSource.schema = userSpecifiedSchemaDDL
134+
}
114135

115136
var read = Read()
116137
read.dataSource = dataSource

Tests/SparkConnectTests/DataFrameReaderTests.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,27 @@ struct DataFrameReaderTests {
8585
})
8686
await spark.stop()
8787
}
88+
89+
@Test
func schema() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let path = "../examples/src/main/resources/people.json"
  // A single-column user-specified schema overrides inference entirely.
  #expect(try await spark.read.schema("age SHORT").json(path).dtypes.count == 1)
  #expect(try await spark.read.schema("age SHORT").json(path).dtypes[0] == ("age", "smallint"))
  // A multi-column schema preserves the declared column order and types.
  // (count check added for symmetry with the single-column case above)
  #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes.count == 2)
  #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[0] == ("age", "smallint"))
  #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[1] == ("name", "string"))
  await spark.stop()
}
99+
100+
@Test
func invalidSchema() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // Both a malformed column name and an unknown type name must be rejected
  // at `schema(_:)` time with `InvalidTypeException`.
  for ddl in ["invalid-name SHORT", "age UNKNOWN_TYPE"] {
    await #expect(throws: SparkConnectError.InvalidTypeException) {
      _ = try await spark.read.schema(ddl)
    }
  }
  await spark.stop()
}
88111
}

0 commit comments

Comments
 (0)