diff --git a/Sources/SparkConnect/DataFrameReader.swift b/Sources/SparkConnect/DataFrameReader.swift index ca00b07..626ff77 100644 --- a/Sources/SparkConnect/DataFrameReader.swift +++ b/Sources/SparkConnect/DataFrameReader.swift @@ -34,6 +34,8 @@ public actor DataFrameReader: Sendable { var extraOptions: CaseInsensitiveDictionary = CaseInsensitiveDictionary([:]) + var userSpecifiedSchemaDDL: String? = nil + let sparkSession: SparkSession init(sparkSession: SparkSession) { @@ -85,6 +87,22 @@ public actor DataFrameReader: Sendable { return self } + /// Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + /// automatically from data. By specifying the schema here, the underlying data source can skip + /// the schema inference step, and thus speed up data loading. + /// - Parameter schema: A DDL schema string. + /// - Returns: A `DataFrameReader`. + public func schema(_ schema: String) async throws -> DataFrameReader { + // Validate by parsing. + do { + _ = try await sparkSession.client.ddlParse(schema) + } catch { + throw SparkConnectError.InvalidTypeException + } + self.userSpecifiedSchemaDDL = schema + return self + } + /// Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external /// key-value stores). /// - Returns: A `DataFrame`. @@ -111,6 +129,9 @@ public actor DataFrameReader: Sendable { dataSource.format = self.source dataSource.paths = self.paths dataSource.options = self.extraOptions.toStringDictionary() + if let userSpecifiedSchemaDDL = self.userSpecifiedSchemaDDL { + dataSource.schema = userSpecifiedSchemaDDL + } var read = Read() read.dataSource = dataSource diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift index 781282e..78968ec 100644 --- a/Tests/SparkConnectTests/DataFrameReaderTests.swift +++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift @@ -85,4 +85,27 @@ struct DataFrameReaderTests { }) await spark.stop() } + + @Test + func schema() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let path = "../examples/src/main/resources/people.json" + #expect(try await spark.read.schema("age SHORT").json(path).dtypes.count == 1) + #expect(try await spark.read.schema("age SHORT").json(path).dtypes[0] == ("age", "smallint")) + #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[0] == ("age", "smallint")) + #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[1] == ("name", "string")) + await spark.stop() + } + + @Test + func invalidSchema() async throws { + let spark = try await SparkSession.builder.getOrCreate() + await #expect(throws: SparkConnectError.InvalidTypeException) { + _ = try await spark.read.schema("invalid-name SHORT") + } + await #expect(throws: SparkConnectError.InvalidTypeException) { + _ = try await spark.read.schema("age UNKNOWN_TYPE") + } + await spark.stop() + } }