Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,10 @@ public actor DataFrame: Sendable {
return try await persist()
}

public func persist(
useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
deserialized: Bool = true, replication: Int32 = 1
)
async throws -> DataFrame
/// Persist this `DataFrame` with the given storage level.
/// - Parameter storageLevel: A storage level to apply.
public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
-> DataFrame
{
try await withGRPCClient(
transport: .http2NIOPosix(
Expand All @@ -308,8 +307,7 @@ public actor DataFrame: Sendable {
) { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
_ = try await service.analyzePlan(
spark.client.getPersist(
spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
spark.client.getPersist(spark.sessionID, plan, storageLevel))
}

return self
Expand Down
13 changes: 2 additions & 11 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -256,23 +256,14 @@ public actor SparkConnectClient {
return request
}

func getPersist(
_ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
_ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
) async
func getPersist(_ sessionID: String, _ plan: Plan, _ storageLevel: StorageLevel) async
-> AnalyzePlanRequest
{
return analyze(
sessionID,
{
var persist = AnalyzePlanRequest.Persist()
var level = StorageLevel()
level.useDisk = useDisk
level.useMemory = useMemory
level.useOffHeap = useOffHeap
level.deserialized = deserialized
level.replication = replication
persist.storageLevel = level
persist.storageLevel = storageLevel.toSparkConnectStorageLevel
persist.relation = plan.root
return OneOf_Analyze.persist(persist)
})
Expand Down
88 changes: 88 additions & 0 deletions Sources/SparkConnect/StorageLevel.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

/// Flags for controlling the storage of an `RDD`. Each ``StorageLevel`` records whether to use memory,
/// or `ExternalBlockStore`, whether to drop the `RDD` to disk if it falls out of memory or
/// `ExternalBlockStore`, whether to keep the data in memory in a serialized format, and whether
/// to replicate the `RDD` partitions on multiple nodes.
public struct StorageLevel: Sendable {
/// Whether the cache should use disk or not.
var useDisk: Bool

/// Whether the cache should use memory or not.
var useMemory: Bool

/// Whether the cache should use off-heap or not.
var useOffHeap: Bool

/// Whether the cached data is deserialized or not.
var deserialized: Bool

/// The number of replicas.
var replication: Int32

init(useDisk: Bool, useMemory: Bool, useOffHeap: Bool, deserialized: Bool, replication: Int32 = 1)
{
self.useDisk = useDisk
self.useMemory = useMemory
self.useOffHeap = useOffHeap
self.deserialized = deserialized
self.replication = replication
}

public static let NONE = StorageLevel(
useDisk: false, useMemory: false, useOffHeap: false, deserialized: false)
public static let DISK_ONLY = StorageLevel(
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false)
public static let DISK_ONLY_2 = StorageLevel(
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false, replication: 2)
public static let DISK_ONLY_3 = StorageLevel(
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false, replication: 3)
public static let MEMORY_ONLY = StorageLevel(
useDisk: false, useMemory: true, useOffHeap: false, deserialized: false)
public static let MEMORY_ONLY_2 = StorageLevel(
useDisk: false, useMemory: true, useOffHeap: false, deserialized: false, replication: 2)
public static let MEMORY_AND_DISK = StorageLevel(
useDisk: true, useMemory: true, useOffHeap: false, deserialized: false)
public static let MEMORY_AND_DISK_2 = StorageLevel(
useDisk: true, useMemory: true, useOffHeap: false, deserialized: false, replication: 2)
public static let OFF_HEAP = StorageLevel(
useDisk: true, useMemory: true, useOffHeap: true, deserialized: false)
public static let MEMORY_AND_DISK_DESER = StorageLevel(
useDisk: true, useMemory: true, useOffHeap: false, deserialized: true)
}

extension StorageLevel {
var toSparkConnectStorageLevel: Spark_Connect_StorageLevel {
var level = Spark_Connect_StorageLevel()
level.useDisk = self.useDisk
level.useMemory = self.useMemory
level.useOffHeap = self.useOffHeap
level.deserialized = self.deserialized
level.replication = self.replication
return level
}
}

extension StorageLevel: CustomStringConvertible {
public var description: String {
return
"StorageLevel(useDisk: \(useDisk), useMemory: \(useMemory), useOffHeap: \(useOffHeap), deserialized: \(deserialized), replication: \(replication))"
}
}
1 change: 0 additions & 1 deletion Sources/SparkConnect/TypeAliases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,5 @@ typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StorageLevel = Spark_Connect_StorageLevel
typealias UserContext = Spark_Connect_UserContext
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
6 changes: 4 additions & 2 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,17 @@ struct DataFrameTests {
func persist() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(20).persist().count() == 20)
#expect(try await spark.range(21).persist(useDisk: false).count() == 21)
#expect(try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21)
await spark.stop()
}

@Test
func persistInvalidStorageLevel() async throws {
let spark = try await SparkSession.builder.getOrCreate()
try await #require(throws: Error.self) {
let _ = try await spark.range(9999).persist(replication: 0).count()
var invalidLevel = StorageLevel.DISK_ONLY
invalidLevel.replication = 0
let _ = try await spark.range(9999).persist(storageLevel: invalidLevel).count()
}
await spark.stop()
}
Expand Down
Loading