
Commit 9930db9

[SPARK-51626] Support DataFrameReader
1 parent 3fe1f0e commit 9930db9

4 files changed: +269 −0
Sources/SparkConnect/DataFrameReader.swift

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Atomics
import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import NIOCore
import SwiftyTextTable
import Synchronization

/// An interface used to load a `DataFrame` from external storage systems
/// (e.g. file systems, key-value stores, etc). Use `SparkSession.read` to access this.
public actor DataFrameReader: Sendable {
  var source: String = ""

  var paths: [String] = []

  // TODO: Case-insensitive Map
  var extraOptions: [String: String] = [:]

  let sparkSession: SparkSession

  init(sparkSession: SparkSession) {
    self.sparkSession = sparkSession
  }

  /// Specifies the input data source format.
  /// - Parameter source: A data source format name, e.g. `"csv"`, `"json"`, `"orc"`, `"parquet"`, or `"avro"`.
  /// - Returns: A `DataFrameReader`.
  public func format(_ source: String) -> DataFrameReader {
    self.source = source
    return self
  }

  /// Adds an input option for the underlying data source.
  /// - Parameters:
  ///   - key: A key string.
  ///   - value: A value string.
  /// - Returns: A `DataFrameReader`.
  public func option(_ key: String, _ value: String) -> DataFrameReader {
    self.extraOptions[key] = value
    return self
  }

  /// Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
  /// key-value stores).
  /// - Returns: A `DataFrame`.
  public func load() -> DataFrame {
    return load([])
  }

  /// Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a
  /// local or distributed file system).
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func load(_ path: String) -> DataFrame {
    return load([path])
  }

  /// Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if
  /// the source is a HadoopFsRelationProvider.
  /// - Parameter paths: An array of path strings.
  /// - Returns: A `DataFrame`.
  public func load(_ paths: [String]) -> DataFrame {
    self.paths = paths

    var dataSource = DataSource()
    dataSource.format = self.source
    dataSource.paths = self.paths
    dataSource.options = self.extraOptions

    var read = Read()
    read.dataSource = dataSource

    var relation = Relation()
    relation.read = read

    var plan = Plan()
    plan.opType = .root(relation)

    return DataFrame(spark: sparkSession, plan: plan)
  }

  /// Loads an Avro file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func avro(_ path: String) -> DataFrame {
    self.source = "avro"
    return load(path)
  }

  /// Loads Avro files and returns the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func avro(_ paths: String...) -> DataFrame {
    self.source = "avro"
    return load(paths)
  }

  /// Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
  /// overloaded `csv()` method for more details.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func csv(_ path: String) -> DataFrame {
    self.source = "csv"
    return load(path)
  }

  /// Loads CSV files and returns the result as a `DataFrame`.
  /// This function will go through the input once to determine the input schema if `inferSchema`
  /// is enabled. To avoid going through the entire data once, disable the `inferSchema` option or
  /// specify the schema explicitly using `schema`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func csv(_ paths: String...) -> DataFrame {
    self.source = "csv"
    return load(paths)
  }

  /// Loads a JSON file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func json(_ path: String) -> DataFrame {
    self.source = "json"
    return load(path)
  }

  /// Loads JSON files and returns the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func json(_ paths: String...) -> DataFrame {
    self.source = "json"
    return load(paths)
  }

  /// Loads an ORC file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func orc(_ path: String) -> DataFrame {
    self.source = "orc"
    return load(path)
  }

  /// Loads ORC files and returns the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func orc(_ paths: String...) -> DataFrame {
    self.source = "orc"
    return load(paths)
  }

  /// Loads a Parquet file and returns the result as a `DataFrame`.
  /// - Parameter path: A path string.
  /// - Returns: A `DataFrame`.
  public func parquet(_ path: String) -> DataFrame {
    self.source = "parquet"
    return load(path)
  }

  /// Loads Parquet files, returning the result as a `DataFrame`.
  /// - Parameter paths: Path strings.
  /// - Returns: A `DataFrame`.
  public func parquet(_ paths: String...) -> DataFrame {
    self.source = "parquet"
    return load(paths)
  }
}
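
Taken together, the reader follows the builder pattern familiar from the Scala and Python APIs: `format` and `option` mutate the actor's state and return `self`, and `load` turns that state into a Spark Connect `Read` plan. A minimal usage sketch, assuming a connected `SparkSession` named `spark` (as in the test suite below); the file paths are illustrative, and `header` is a standard Spark CSV option:

// Generic route: choose a format, add source-specific options, then load.
let csvDF = try await spark.read
  .format("csv")
  .option("header", "true")
  .load("/tmp/people.csv")

// Shorthand route: the per-format helpers set `source` and delegate to `load`.
let jsonDF = try await spark.read.json("/tmp/people.json")
print(try await csvDF.count())
print(try await jsonDF.count())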

Sources/SparkConnect/SparkSession.swift

Lines changed: 6 additions & 0 deletions
@@ -116,6 +116,12 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }

+  var read: DataFrameReader {
+    get {
+      return DataFrameReader(sparkSession: self)
+    }
+  }
+
   /// This is defined as the return type of `SparkSession.sparkContext` method.
   /// This is an empty `Struct` type because `sparkContext` method is designed to throw
   /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
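
Because `read` is a computed property, every access constructs a fresh `DataFrameReader`, so options set for one read never leak into the next. A small sketch under the same assumptions as above (a connected session `spark` and hypothetical paths):

let withHeader = try await spark.read.option("header", "true").csv("/tmp/a.csv")
// A new reader each time: the previous "header" option is not carried over.
let plain = try await spark.read.csv("/tmp/b.csv")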

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
 typealias ConfigRequest = Spark_Connect_ConfigRequest
+typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
@@ -29,6 +30,7 @@ typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
 typealias Plan = Spark_Connect_Plan
 typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
+typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
Tests/SparkConnectTests/DataFrameReaderTests.swift

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import Testing

@testable import SparkConnect

/// A test suite for `DataFrameReader`
struct DataFrameReaderTests {

  @Test
  func avro() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/users.avro"
    #expect(try await spark.read.format("avro").load(path).count() == 2)
    #expect(try await spark.read.avro(path).count() == 2)
    #expect(try await spark.read.avro(path, path).count() == 4)
    await spark.stop()
  }

  @Test
  func csv() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/people.csv"
    #expect(try await spark.read.format("csv").load(path).count() == 3)
    #expect(try await spark.read.csv(path).count() == 3)
    #expect(try await spark.read.csv(path, path).count() == 6)
    await spark.stop()
  }

  @Test
  func json() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/people.json"
    #expect(try await spark.read.format("json").load(path).count() == 3)
    #expect(try await spark.read.json(path).count() == 3)
    #expect(try await spark.read.json(path, path).count() == 6)
    await spark.stop()
  }

  @Test
  func orc() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/users.orc"
    #expect(try await spark.read.format("orc").load(path).count() == 2)
    #expect(try await spark.read.orc(path).count() == 2)
    #expect(try await spark.read.orc(path, path).count() == 4)
    await spark.stop()
  }

  @Test
  func parquet() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let path = "../examples/src/main/resources/users.parquet"
    #expect(try await spark.read.format("parquet").load(path).count() == 2)
    #expect(try await spark.read.parquet(path).count() == 2)
    #expect(try await spark.read.parquet(path, path).count() == 4)
    await spark.stop()
  }
}
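
Each test passes the same path twice to exercise the variadic overloads: the two paths are unioned into one scan, so the expected row count doubles. All of the suites assume a reachable Spark Connect server behind `SparkSession.builder.getOrCreate()` and the Spark example resources at the relative paths shown. A sketch of the doubling pattern:

let spark = try await SparkSession.builder.getOrCreate()
let path = "../examples/src/main/resources/users.parquet"
let once = try await spark.read.parquet(path).count()        // the example file holds 2 rows
let twice = try await spark.read.parquet(path, path).count() // same file listed twice
#expect(twice == once * 2)
await spark.stop()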
