Skip to content

Commit 4f0f66a

Browse files
committed
added: Channel
1 parent a1d17b0 commit 4f0f66a

File tree

1 file changed

+112
-34
lines changed

1 file changed

+112
-34
lines changed

Sources/DataLoader/DataLoader.swift

Lines changed: 112 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,24 @@ public enum DataLoaderValue<T: Sendable>: Sendable {
77
case failure(Error)
88
}
99

10+
actor Concurrent<T> {
11+
var wrappedValue: T
12+
13+
func nonmutating<Returned>(_ action: (T) throws -> Returned) async rethrows -> Returned {
14+
try action(wrappedValue)
15+
}
16+
17+
func mutating<Returned>(_ action: (inout T) throws -> Returned) async rethrows -> Returned {
18+
try action(&wrappedValue)
19+
}
20+
21+
init(_ value: T) {
22+
wrappedValue = value
23+
}
24+
}
25+
1026
public typealias BatchLoadFunction<Key: Hashable & Sendable, Value: Sendable> = @Sendable (_ keys: [Key]) async throws -> [DataLoaderValue<Value>]
11-
private typealias LoaderQueue<Key: Hashable & Sendable, Value: Sendable> = [(key: Key, channel: AsyncThrowingChannel<Value, Error>)]
27+
private typealias LoaderQueue<Key: Hashable & Sendable, Value: Sendable> = [(key: Key, channel: Channel<Value, Error>)]
1228

1329
/// DataLoader creates a public API for loading data from a particular
1430
/// data back-end with unique keys such as the id column of a SQL table
@@ -22,7 +38,7 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
2238
private let batchLoadFunction: BatchLoadFunction<Key, Value>
2339
private let options: DataLoaderOptions<Key, Value>
2440

25-
private var cache = [Key: Value]()
41+
private var cache = [Key: Channel<Value, Error>]()
2642
private var queue = LoaderQueue<Key, Value>()
2743

2844
private var dispatchScheduled = false
@@ -40,10 +56,10 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
4056
let cacheKey = options.cacheKeyFunction?(key) ?? key
4157

4258
if options.cachingEnabled, let cached = cache[cacheKey] {
43-
return cached
59+
return try await cached.value
4460
}
4561

46-
let channel = AsyncThrowingChannel<Value, Error>()
62+
let channel = Channel<Value, Error>()
4763

4864
if options.batchingEnabled {
4965
queue.append((key: key, channel: channel))
@@ -59,38 +75,27 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
5975
do {
6076
let results = try await self.batchLoadFunction([key])
6177
if results.isEmpty {
62-
channel.fail(DataLoaderError.noValueForKey("Did not return value for key: \(key)"))
78+
await channel.fail(with: DataLoaderError.noValueForKey("Did not return value for key: \(key)"))
6379
} else {
6480
let result = results[0]
6581
switch result {
6682
case let .success(value):
67-
await channel.send(value)
68-
channel.finish()
83+
await channel.fulfill(with: value)
6984
case let .failure(error):
70-
channel.fail(error)
85+
await channel.fail(with: error)
7186
}
7287
}
7388
} catch {
74-
channel.fail(error)
89+
await channel.fail(with: error)
7590
}
7691
}
7792
}
7893

79-
var value: Value?
80-
81-
for try await channelResult in channel {
82-
value = channelResult
83-
}
84-
85-
guard let value else {
86-
throw DataLoaderError.noValueForKey("Did not return value for key: \(key)")
87-
}
88-
8994
if options.cachingEnabled {
90-
cache[cacheKey] = value
95+
cache[cacheKey] = channel
9196
}
9297

93-
return value
98+
return try await channel.value
9499
}
95100

96101
/// Loads multiple keys, promising an array of values:
@@ -146,16 +151,12 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
146151
let cacheKey = options.cacheKeyFunction?(key) ?? key
147152

148153
if cache[cacheKey] == nil {
149-
let channel = AsyncThrowingChannel<Value, Error>()
154+
let channel = Channel<Value, Error>()
150155
Task.detached {
151-
await channel.send(value)
152-
153-
channel.finish()
156+
await channel.fulfill(with: value)
154157
}
155158

156-
for try await channelResult in channel {
157-
cache[cacheKey] = channelResult
158-
}
159+
cache[cacheKey] = channel
159160
}
160161

161162
return self
@@ -204,21 +205,98 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
204205

205206
switch result {
206207
case let .failure(error):
207-
entry.element.channel.fail(error)
208+
await entry.element.channel.fail(with: error)
208209
case let .success(value):
209-
await entry.element.channel.send(value)
210-
entry.element.channel.finish()
210+
await entry.element.channel.fulfill(with: value)
211211
}
212212
}
213213
} catch {
214-
failedExecution(batch: batch, error: error)
214+
await failedExecution(batch: batch, error: error)
215215
}
216216
}
217217

218-
private func failedExecution(batch: LoaderQueue<Key, Value>, error: Error) {
218+
private func failedExecution(batch: LoaderQueue<Key, Value>, error: Error) async {
219219
for (key, channel) in batch {
220220
_ = clear(key: key)
221-
channel.fail(error)
221+
await channel.fail(with: error)
222+
}
223+
}
224+
}
225+
226+
public actor Channel<Success: Sendable, Failure: Error>: Sendable {
227+
typealias Waiter = CheckedContinuation<Success, Error>
228+
229+
private actor State {
230+
var waiters = [Waiter]()
231+
var result: Success? = nil
232+
var failure: Failure? = nil
233+
234+
func setResult(result: Success) {
235+
self.result = result
236+
}
237+
238+
func setFailure(failure: Failure) {
239+
self.failure = failure
240+
}
241+
242+
func appendWaiters(waiters: Waiter...) {
243+
self.waiters.append(contentsOf: waiters)
222244
}
245+
246+
func removeAllWaiters() {
247+
self.waiters.removeAll()
248+
}
249+
}
250+
251+
private var state = State()
252+
253+
public init(_ elementType: Success.Type = Success.self) {}
254+
255+
@discardableResult
256+
public func fulfill(with value: Success) async -> Bool {
257+
if await state.result == nil {
258+
await state.setResult(result:value)
259+
for waiters in await state.waiters {
260+
waiters.resume(returning: value)
261+
}
262+
await state.removeAllWaiters()
263+
return false
264+
}
265+
return true
266+
}
267+
268+
@discardableResult
269+
public func fail(with failure: Failure) async -> Bool {
270+
if await state.failure == nil {
271+
await state.setFailure(failure: failure)
272+
for waiters in await state.waiters {
273+
waiters.resume(throwing: failure)
274+
}
275+
await state.removeAllWaiters()
276+
return false
277+
}
278+
return true
279+
}
280+
281+
public var value: Success {
282+
get async throws {
283+
try await withCheckedThrowingContinuation { continuation in
284+
Task {
285+
if let result = await state.result {
286+
continuation.resume(returning: result)
287+
} else if let failure = await self.state.failure {
288+
continuation.resume(throwing: failure)
289+
} else {
290+
await state.appendWaiters(waiters: continuation)
291+
}
292+
}
293+
}
294+
}
295+
}
296+
}
297+
298+
extension Channel where Success == Void {
299+
func fulfill() async -> Bool {
300+
return await fulfill(with: ())
223301
}
224302
}

0 commit comments

Comments
 (0)