@@ -7,8 +7,24 @@ public enum DataLoaderValue<T: Sendable>: Sendable {
7
7
case failure( Error )
8
8
}
9
9
10
+ actor Concurrent < T> {
11
+ var wrappedValue : T
12
+
13
+ func nonmutating< Returned> ( _ action: ( T ) throws -> Returned ) async rethrows -> Returned {
14
+ try action ( wrappedValue)
15
+ }
16
+
17
+ func mutating< Returned> ( _ action: ( inout T ) throws -> Returned ) async rethrows -> Returned {
18
+ try action ( & wrappedValue)
19
+ }
20
+
21
+ init ( _ value: T ) {
22
+ wrappedValue = value
23
+ }
24
+ }
25
+
10
26
public typealias BatchLoadFunction < Key: Hashable & Sendable , Value: Sendable > = @Sendable ( _ keys: [ Key ] ) async throws -> [ DataLoaderValue < Value > ]
11
- private typealias LoaderQueue < Key: Hashable & Sendable , Value: Sendable > = [ ( key: Key , channel: AsyncThrowingChannel < Value , Error > ) ]
27
+ private typealias LoaderQueue < Key: Hashable & Sendable , Value: Sendable > = [ ( key: Key , channel: Channel < Value , Error > ) ]
12
28
13
29
/// DataLoader creates a public API for loading data from a particular
14
30
/// data back-end with unique keys such as the id column of a SQL table
@@ -22,7 +38,7 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
22
38
private let batchLoadFunction : BatchLoadFunction < Key , Value >
23
39
private let options : DataLoaderOptions < Key , Value >
24
40
25
- private var cache = [ Key: Value] ( )
41
+ private var cache = [ Key : Channel < Value , Error > ] ( )
26
42
private var queue = LoaderQueue < Key , Value > ( )
27
43
28
44
private var dispatchScheduled = false
@@ -40,10 +56,10 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
40
56
let cacheKey = options. cacheKeyFunction ? ( key) ?? key
41
57
42
58
if options. cachingEnabled, let cached = cache [ cacheKey] {
43
- return cached
59
+ return try await cached. value
44
60
}
45
61
46
- let channel = AsyncThrowingChannel < Value , Error > ( )
62
+ let channel = Channel < Value , Error > ( )
47
63
48
64
if options. batchingEnabled {
49
65
queue. append ( ( key: key, channel: channel) )
@@ -59,38 +75,27 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
59
75
do {
60
76
let results = try await self . batchLoadFunction ( [ key] )
61
77
if results. isEmpty {
62
- channel. fail ( DataLoaderError . noValueForKey ( " Did not return value for key: \( key) " ) )
78
+ await channel. fail ( with : DataLoaderError . noValueForKey ( " Did not return value for key: \( key) " ) )
63
79
} else {
64
80
let result = results [ 0 ]
65
81
switch result {
66
82
case let . success( value) :
67
- await channel. send ( value)
68
- channel. finish ( )
83
+ await channel. fulfill ( with: value)
69
84
case let . failure( error) :
70
- channel. fail ( error)
85
+ await channel. fail ( with : error)
71
86
}
72
87
}
73
88
} catch {
74
- channel. fail ( error)
89
+ await channel. fail ( with : error)
75
90
}
76
91
}
77
92
}
78
93
79
- var value : Value ?
80
-
81
- for try await channelResult in channel {
82
- value = channelResult
83
- }
84
-
85
- guard let value else {
86
- throw DataLoaderError . noValueForKey ( " Did not return value for key: \( key) " )
87
- }
88
-
89
94
if options. cachingEnabled {
90
- cache [ cacheKey] = value
95
+ cache [ cacheKey] = channel
91
96
}
92
97
93
- return value
98
+ return try await channel . value
94
99
}
95
100
96
101
/// Loads multiple keys, promising an array of values:
@@ -146,16 +151,12 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
146
151
let cacheKey = options. cacheKeyFunction ? ( key) ?? key
147
152
148
153
if cache [ cacheKey] == nil {
149
- let channel = AsyncThrowingChannel < Value , Error > ( )
154
+ let channel = Channel < Value , Error > ( )
150
155
Task . detached {
151
- await channel. send ( value)
152
-
153
- channel. finish ( )
156
+ await channel. fulfill ( with: value)
154
157
}
155
158
156
- for try await channelResult in channel {
157
- cache [ cacheKey] = channelResult
158
- }
159
+ cache [ cacheKey] = channel
159
160
}
160
161
161
162
return self
@@ -204,21 +205,98 @@ public actor DataLoader<Key: Hashable & Sendable, Value: Sendable> {
204
205
205
206
switch result {
206
207
case let . failure( error) :
207
- entry. element. channel. fail ( error)
208
+ await entry. element. channel. fail ( with : error)
208
209
case let . success( value) :
209
- await entry. element. channel. send ( value)
210
- entry. element. channel. finish ( )
210
+ await entry. element. channel. fulfill ( with: value)
211
211
}
212
212
}
213
213
} catch {
214
- failedExecution ( batch: batch, error: error)
214
+ await failedExecution ( batch: batch, error: error)
215
215
}
216
216
}
217
217
218
- private func failedExecution( batch: LoaderQueue < Key , Value > , error: Error ) {
218
+ private func failedExecution( batch: LoaderQueue < Key , Value > , error: Error ) async {
219
219
for (key, channel) in batch {
220
220
_ = clear ( key: key)
221
- channel. fail ( error)
221
+ await channel. fail ( with: error)
222
+ }
223
+ }
224
+ }
225
+
226
+ public actor Channel< Success: Sendable , Failure: Error > : Sendable {
227
+ typealias Waiter = CheckedContinuation < Success , Error >
228
+
229
+ private actor State {
230
+ var waiters = [ Waiter] ( )
231
+ var result : Success ? = nil
232
+ var failure : Failure ? = nil
233
+
234
+ func setResult( result: Success ) {
235
+ self . result = result
236
+ }
237
+
238
+ func setFailure( failure: Failure ) {
239
+ self . failure = failure
240
+ }
241
+
242
+ func appendWaiters( waiters: Waiter ... ) {
243
+ self . waiters. append ( contentsOf: waiters)
222
244
}
245
+
246
+ func removeAllWaiters( ) {
247
+ self . waiters. removeAll ( )
248
+ }
249
+ }
250
+
251
+ private var state = State ( )
252
+
253
+ public init ( _ elementType: Success . Type = Success . self) { }
254
+
255
+ @discardableResult
256
+ public func fulfill( with value: Success ) async -> Bool {
257
+ if await state. result == nil {
258
+ await state. setResult ( result: value)
259
+ for waiters in await state. waiters {
260
+ waiters. resume ( returning: value)
261
+ }
262
+ await state. removeAllWaiters ( )
263
+ return false
264
+ }
265
+ return true
266
+ }
267
+
268
+ @discardableResult
269
+ public func fail( with failure: Failure) async - > Bool {
270
+ if await state. failure == nil {
271
+ await state. setFailure ( failure: failure)
272
+ for waiters in await state. waiters {
273
+ waiters. resume ( throwing: failure)
274
+ }
275
+ await state. removeAllWaiters ( )
276
+ return false
277
+ }
278
+ return true
279
+ }
280
+
281
+ public var value: Success {
282
+ get async throws {
283
+ try await withCheckedThrowingContinuation { continuation in
284
+ Task {
285
+ if let result = await state. result {
286
+ continuation. resume ( returning: result)
287
+ } else if let failure = await self . state. failure {
288
+ continuation. resume ( throwing: failure)
289
+ } else {
290
+ await state. appendWaiters ( waiters: continuation)
291
+ }
292
+ }
293
+ }
294
+ }
295
+ }
296
+ }
297
+
298
+ extension Channel where Success == Void {
299
+ func fulfill( ) async -> Bool {
300
+ return await fulfill ( with: ( ) )
223
301
}
224
302
}
0 commit comments