
Commit f97822d
[SPARK-51626] Support DataFrameReader
### What changes were proposed in this pull request?

This PR aims to support `DataFrameReader`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No, this is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #29 from dongjoon-hyun/SPARK-51626.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
Parent: 3fe1f0e

File tree: 4 files changed (+243, -0 lines)
Sources/SparkConnect/DataFrameReader.swift

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Atomics
import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import NIOCore
import SwiftyTextTable
import Synchronization

/// An interface used to load a `DataFrame` from external storage systems
/// (e.g. file systems, key-value stores, etc.). Use `SparkSession.read` to access this.
public actor DataFrameReader: Sendable {
  var source: String = ""

  var paths: [String] = []

  // TODO: Case-insensitive Map
  var extraOptions: [String: String] = [:]

  let sparkSession: SparkSession

  init(sparkSession: SparkSession) {
    self.sparkSession = sparkSession
  }

  /// Specifies the input data source format.
  /// - Parameter source: A string.
  /// - Returns: A `DataFrameReader`.
  public func format(_ source: String) -> DataFrameReader {
    self.source = source
    return self
  }

  /// Adds an input option for the underlying data source.
  /// - Parameters:
  ///   - key: A key string.
  ///   - value: A value string.
  /// - Returns: A `DataFrameReader`.
  public func option(_ key: String, _ value: String) -> DataFrameReader {
    self.extraOptions[key] = value
    return self
  }

  /// Loads input as a `DataFrame`, for data sources that don't require a path (e.g. external
  /// key-value stores).
  /// - Returns: A `DataFrame`.
  public func load() -> DataFrame {
    return load([])
  }

  /// Loads input as a `DataFrame`, for data sources that require a path (e.g. data backed by a
  /// local or distributed file system).
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func load(_ path: String) -> DataFrame {
    return load([path])
  }

  /// Loads input as a `DataFrame`, for data sources that support multiple paths. Only works if
  /// the source is a HadoopFsRelationProvider.
  /// - Parameter paths: An array of path strings.
  /// - Returns: A `DataFrame`.
  public func load(_ paths: [String]) -> DataFrame {
    self.paths = paths

    var dataSource = DataSource()
    dataSource.format = self.source
    dataSource.paths = self.paths
    dataSource.options = self.extraOptions

    var read = Read()
    read.dataSource = dataSource

    var relation = Relation()
    relation.read = read

    var plan = Plan()
    plan.opType = .root(relation)

    return DataFrame(spark: sparkSession, plan: plan)
  }

  /// Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
  /// overloaded `csv()` method for more details.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func csv(_ path: String) -> DataFrame {
    self.source = "csv"
    return load(path)
  }

  /// Loads CSV files and returns the result as a `DataFrame`.
  /// This function will go through the input once to determine the input schema if `inferSchema`
  /// is enabled. To avoid going through the entire data once, disable the `inferSchema` option or
  /// specify the schema explicitly using `schema`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func csv(_ paths: String...) -> DataFrame {
    self.source = "csv"
    return load(paths)
  }

  /// Loads a JSON file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func json(_ path: String) -> DataFrame {
    self.source = "json"
    return load(path)
  }

  /// Loads JSON files and returns the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func json(_ paths: String...) -> DataFrame {
    self.source = "json"
    return load(paths)
  }

  /// Loads an ORC file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func orc(_ path: String) -> DataFrame {
    self.source = "orc"
    return load(path)
  }

  /// Loads ORC files and returns the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func orc(_ paths: String...) -> DataFrame {
    self.source = "orc"
    return load(paths)
  }

  /// Loads a Parquet file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func parquet(_ path: String) -> DataFrame {
    self.source = "parquet"
    return load(path)
  }

  /// Loads Parquet files, returning the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func parquet(_ paths: String...) -> DataFrame {
    self.source = "parquet"
    return load(paths)
  }
}
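
For orientation, a minimal, hypothetical usage sketch of the builder-style API above. The option key and file path are illustrative, and it assumes a running Spark Connect server plus same-module access to the `SparkSession.read` accessor introduced below:

let spark = try await SparkSession.builder.getOrCreate()
// A single await covers the chained actor calls; load() builds a lazy DataFrame.
let df = await spark.read
  .format("csv")
  .option("header", "true")  // illustrative key; forwarded verbatim to the data source
  .load("/path/to/data.csv")  // illustrative path
print(try await df.count())  // triggers execution on the server
await spark.stop()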

Sources/SparkConnect/SparkSession.swift

Lines changed: 6 additions & 0 deletions
@@ -116,6 +116,12 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }

+  var read: DataFrameReader {
+    get {
+      return DataFrameReader(sparkSession: self)
+    }
+  }
+
   /// This is defined as the return type of `SparkSession.sparkContext` method.
   /// This is an empty `Struct` type because `sparkContext` method is designed to throw
   /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
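
Since `read` is a computed property, every access returns a fresh `DataFrameReader`, so options set on one reader never leak into the next. A short sketch of the two equivalent call styles this accessor enables (assuming an active `spark` session as above; paths illustrative):

// Long form: explicit format, then load.
let df1 = await spark.read.format("json").load("/path/to/people.json")
// Shorthand: json(_:) sets the format internally.
let df2 = await spark.read.json("/path/to/people.json")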

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
 typealias ConfigRequest = Spark_Connect_ConfigRequest
+typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
@@ -29,6 +30,7 @@ typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
+typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
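
For context, the two new aliases keep the plan-building code in `DataFrameReader.load` free of generated protobuf prefixes; note that `DataSource` is a message nested inside the generated `Spark_Connect_Read` type. A sketch of how they compose, mirroring `load(_:)` above (field values illustrative):

var dataSource = DataSource()  // i.e. Spark_Connect_Read.DataSource
dataSource.format = "parquet"
dataSource.paths = ["/path/to/users.parquet"]  // illustrative path

var read = Read()  // i.e. Spark_Connect_Read
read.dataSource = dataSource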
Tests/SparkConnectTests/DataFrameReaderTests.swift

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import Testing

@testable import SparkConnect

/// A test suite for `DataFrameReader`.
struct DataFrameReaderTests {

  @Test
  func csv() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/people.csv"
    #expect(try await spark.read.format("csv").load(path).count() == 3)
    #expect(try await spark.read.csv(path).count() == 3)
    #expect(try await spark.read.csv(path, path).count() == 6)
    await spark.stop()
  }

  @Test
  func json() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/people.json"
    #expect(try await spark.read.format("json").load(path).count() == 3)
    #expect(try await spark.read.json(path).count() == 3)
    #expect(try await spark.read.json(path, path).count() == 6)
    await spark.stop()
  }

  @Test
  func orc() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/users.orc"
    #expect(try await spark.read.format("orc").load(path).count() == 2)
    #expect(try await spark.read.orc(path).count() == 2)
    #expect(try await spark.read.orc(path, path).count() == 4)
    await spark.stop()
  }

  @Test
  func parquet() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/users.parquet"
    #expect(try await spark.read.format("parquet").load(path).count() == 2)
    #expect(try await spark.read.parquet(path).count() == 2)
    #expect(try await spark.read.parquet(path, path).count() == 4)
    await spark.stop()
  }
}
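
These suites assume a reachable Spark Connect server and Spark's bundled example data: the relative `../examples/src/main/resources/...` paths imply the tests run next to an Apache Spark source checkout, and the expected counts (3 rows for `people.csv`/`people.json`, 2 for `users.orc`/`users.parquet`) come from those files. Passing the same path twice, as each suite does, is a compact way to exercise the multi-path `load(_:)` overload, since the rows are simply read twice.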
