Skip to content

Commit dfa3717

Browse files
committed
refactor: refactor continuation management to prevent race condition
1 parent 3899792 commit dfa3717

File tree

10 files changed

+225
-372
lines changed

10 files changed

+225
-372
lines changed

Sources/AsyncObjects/AsyncEvent.swift

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Foundation
99
public actor AsyncEvent: AsyncObject {
1010
/// The suspended tasks continuation type.
1111
private typealias Continuation = GlobalContinuation<Void, Error>
12-
/// The continuations stored with an associated key for all the suspended task that are waitig for event signal.
12+
/// The continuations stored with an associated key for all the suspended task that are waiting for event signal.
1313
private var continuations: [UUID: Continuation] = [:]
1414
/// Indicates whether current stateof event is signaled.
1515
private var signaled: Bool
@@ -28,12 +28,35 @@ public actor AsyncEvent: AsyncObject {
2828
}
2929

3030
/// Remove continuation associated with provided key
31-
/// from `continuations` map.
31+
/// from `continuations` map and resumes with `CancellationError`.
3232
///
3333
/// - Parameter key: The key in the map.
3434
@inline(__always)
3535
private func removeContinuation(withKey key: UUID) {
36-
continuations.removeValue(forKey: key)
36+
let continuation = continuations.removeValue(forKey: key)
37+
continuation?.resume(throwing: CancellationError())
38+
}
39+
40+
/// Suspends the current task, then calls the given closure with a throwing continuation for the current task.
41+
/// Continuation can be cancelled with error if current task is cancelled, by invoking `removeContinuation`.
42+
///
43+
/// Spins up a new continuation and requests to track it with key by invoking `addContinuation`.
44+
/// This operation cooperatively checks for cancellation and reacting to it by invoking `removeContinuation`.
45+
/// Continuation can be resumed with error and some cleanup code can be run here.
46+
///
47+
/// - Throws: If `resume(throwing:)` is called on the continuation, this function throws that error.
48+
@inline(__always)
49+
private func withPromisedContinuation() async throws {
50+
let key = UUID()
51+
try await withTaskCancellationHandler { [weak self] in
52+
Task { [weak self] in
53+
await self?.removeContinuation(withKey: key)
54+
}
55+
} operation: { () -> Continuation.Success in
56+
try await Continuation.with { continuation in
57+
self.addContinuation(continuation, withKey: key)
58+
}
59+
}
3760
}
3861

3962
/// Creates a new event with signal state provided.
@@ -73,18 +96,6 @@ public actor AsyncEvent: AsyncObject {
7396
@Sendable
7497
public func wait() async {
7598
guard !signaled else { return }
76-
let key = UUID()
77-
try? await withThrowingContinuationCancellationHandler(
78-
handler: { [weak self] continuation in
79-
Task { [weak self] in
80-
await self?.removeContinuation(withKey: key)
81-
}
82-
},
83-
{ [weak self] continuation in
84-
Task { [weak self] in
85-
await self?.addContinuation(continuation, withKey: key)
86-
}
87-
}
88-
)
99+
try? await withPromisedContinuation()
89100
}
90101
}

Sources/AsyncObjects/AsyncObject.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import Foundation
2+
13
/// A result value indicating whether a task finished before a specified time.
24
@frozen
35
public enum TaskTimeoutResult: Hashable {
@@ -70,6 +72,7 @@ public extension AsyncObject where Self: AnyObject {
7072
/// and returns only when all the invokation completes.
7173
///
7274
/// - Parameter objects: The objects to wait for.
75+
@inlinable
7376
public func waitForAll(_ objects: [any AsyncObject]) async {
7477
await withTaskGroup(of: Void.self) { group in
7578
objects.forEach { group.addTask(operation: $0.wait) }
@@ -83,6 +86,7 @@ public func waitForAll(_ objects: [any AsyncObject]) async {
8386
/// and returns only when all the invokation completes.
8487
///
8588
/// - Parameter objects: The objects to wait for.
89+
@inlinable
8690
public func waitForAll(_ objects: any AsyncObject...) async {
8791
await waitForAll(objects)
8892
}
@@ -98,6 +102,7 @@ public func waitForAll(_ objects: any AsyncObject...) async {
98102
/// - objects: The objects to wait for.
99103
/// - duration: The duration in nano seconds to wait until.
100104
/// - Returns: The result indicating whether wait completed or timed out.
105+
@inlinable
101106
public func waitForAll(
102107
_ objects: [any AsyncObject],
103108
forNanoseconds duration: UInt64
@@ -118,6 +123,7 @@ public func waitForAll(
118123
/// - objects: The objects to wait for.
119124
/// - duration: The duration in nano seconds to wait until.
120125
/// - Returns: The result indicating whether wait completed or timed out.
126+
@inlinable
121127
public func waitForAll(
122128
_ objects: any AsyncObject...,
123129
forNanoseconds duration: UInt64
@@ -132,6 +138,7 @@ public func waitForAll(
132138
/// and returns when any of the invokation completes.
133139
///
134140
/// - Parameter objects: The objects to wait for.
141+
@inlinable
135142
public func waitForAny(_ objects: [any AsyncObject]) async {
136143
await withTaskGroup(of: Void.self) { group in
137144
objects.forEach { group.addTask(operation: $0.wait) }
@@ -146,6 +153,7 @@ public func waitForAny(_ objects: [any AsyncObject]) async {
146153
/// and returns when any of the invokation completes.
147154
///
148155
/// - Parameter objects: The objects to wait for.
156+
@inlinable
149157
public func waitForAny(_ objects: any AsyncObject...) async {
150158
await waitForAny(objects)
151159
}
@@ -161,6 +169,7 @@ public func waitForAny(_ objects: any AsyncObject...) async {
161169
/// - objects: The objects to wait for.
162170
/// - duration: The duration in nano seconds to wait until.
163171
/// - Returns: The result indicating whether wait completed or timed out.
172+
@inlinable
164173
public func waitForAny(
165174
_ objects: [any AsyncObject],
166175
forNanoseconds duration: UInt64
@@ -181,6 +190,7 @@ public func waitForAny(
181190
/// - objects: The objects to wait for.
182191
/// - duration: The duration in nano seconds to wait until.
183192
/// - Returns: The result indicating whether wait completed or timed out.
193+
@inlinable
184194
public func waitForAny(
185195
_ objects: any AsyncObject...,
186196
forNanoseconds duration: UInt64
@@ -205,11 +215,11 @@ public func waitForTaskCompletion(
205215
) async -> TaskTimeoutResult {
206216
var timedOut = true
207217
await withTaskGroup(of: Bool.self) { group in
208-
group.addTask {
218+
group.addTask(priority: .high) {
209219
await task()
210220
return !Task.isCancelled
211221
}
212-
group.addTask {
222+
group.addTask(priority: .high) {
213223
(try? await Task.sleep(nanoseconds: timeout)) == nil
214224
}
215225
for await result in group.prefix(1) {

Sources/AsyncObjects/AsyncSemaphore.swift

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import OrderedCollections
1212
public actor AsyncSemaphore: AsyncObject {
1313
/// The suspended tasks continuation type.
1414
private typealias Continuation = GlobalContinuation<Void, Error>
15-
/// The continuations stored with an associated key for all the suspended task that are waitig for access to resource.
15+
/// The continuations stored with an associated key for all the suspended task that are waiting for access to resource.
1616
private var continuations: OrderedDictionary<UUID, Continuation> = [:]
1717
/// Pool size for concurrent resource access.
1818
/// Has value provided during initialization incremented by one.
@@ -35,12 +35,13 @@ public actor AsyncSemaphore: AsyncObject {
3535
}
3636

3737
/// Remove continuation associated with provided key
38-
/// from `continuations` map.
38+
/// from `continuations` map and resumes with `CancellationError`.
3939
///
4040
/// - Parameter key: The key in the map.
4141
@inline(__always)
4242
private func removeContinuation(withKey key: UUID) {
43-
continuations.removeValue(forKey: key)
43+
let continuation = continuations.removeValue(forKey: key)
44+
continuation?.resume(throwing: CancellationError())
4445
incrementCount()
4546
}
4647

@@ -51,6 +52,28 @@ public actor AsyncSemaphore: AsyncObject {
5152
count += 1
5253
}
5354

55+
/// Suspends the current task, then calls the given closure with a throwing continuation for the current task.
56+
/// Continuation can be cancelled with error if current task is cancelled, by invoking `removeContinuation`.
57+
///
58+
/// Spins up a new continuation and requests to track it with key by invoking `addContinuation`.
59+
/// This operation cooperatively checks for cancellation and reacting to it by invoking `removeContinuation`.
60+
/// Continuation can be resumed with error and some cleanup code can be run here.
61+
///
62+
/// - Throws: If `resume(throwing:)` is called on the continuation, this function throws that error.
63+
@inline(__always)
64+
private func withPromisedContinuation() async throws {
65+
let key = UUID()
66+
try await withTaskCancellationHandler { [weak self] in
67+
Task { [weak self] in
68+
await self?.removeContinuation(withKey: key)
69+
}
70+
} operation: { () -> Continuation.Success in
71+
try await Continuation.with { continuation in
72+
self.addContinuation(continuation, withKey: key)
73+
}
74+
}
75+
}
76+
5477
/// Creates new counting semaphore with an initial value.
5578
/// By default, initial value is zero.
5679
///
@@ -88,18 +111,6 @@ public actor AsyncSemaphore: AsyncObject {
88111
public func wait() async {
89112
count -= 1
90113
if count > 0 { return }
91-
let key = UUID()
92-
try? await withThrowingContinuationCancellationHandler(
93-
handler: { [weak self] continuation in
94-
Task { [weak self] in
95-
await self?.removeContinuation(withKey: key)
96-
}
97-
},
98-
{ [weak self] continuation in
99-
Task { [weak self] in
100-
await self?.addContinuation(continuation, withKey: key)
101-
}
102-
}
103-
)
114+
try? await withPromisedContinuation()
104115
}
105116
}

Sources/AsyncObjects/ContinuationWrapper.swift renamed to Sources/AsyncObjects/Continuable.swift

Lines changed: 53 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,3 @@
1-
/// Suspends the current task, then calls the given closure with a throwing continuation for the current task.
2-
/// Continuation is cancelled with error if current task is cancelled and cancellation handler is immediately invoked.
3-
///
4-
/// This operation cooperatively checks for cancellation and reacting to it by cancelling the throwing continuation with an error
5-
/// and the cancellation handler is always and immediately invoked after that.
6-
/// For example, even if the operation is running code that never checks for cancellation,
7-
/// a cancellation handler still runs and provides a chance to run some cleanup code.
8-
///
9-
/// - Parameters:
10-
/// - handler: A closure that is called after cancelling continuation.
11-
/// You must not resume the continuation in closure.
12-
/// - fn: A closure that takes a throwing continuation parameter.
13-
/// You must resume the continuation exactly once.
14-
///
15-
/// - Returns: The value passed to the continuation.
16-
/// - Throws: If `resume(throwing:)` is called on the continuation, this function throws that error.
17-
///
18-
/// - Important: The continuation provided in cancellation handler is already resumed with cancellation error.
19-
/// Trying to resume the continuation here will cause runtime error/unexpected behavior.
20-
func withThrowingContinuationCancellationHandler<C: ThrowingContinuable>(
21-
handler: @Sendable (C) -> Void,
22-
_ fn: (C) -> Void
23-
) async throws -> C.Success where C.Failure == Error {
24-
let wrapper = ContinuationWrapper<C>()
25-
let value = try await withTaskCancellationHandler {
26-
guard let continuation = wrapper.value else { return }
27-
wrapper.cancel(withError: CancellationError())
28-
handler(continuation)
29-
} operation: { () -> C.Success in
30-
let value = try await C.with { continuation in
31-
wrapper.value = continuation
32-
fn(continuation)
33-
}
34-
return value
35-
}
36-
return value
37-
}
38-
39-
/// Wrapper type used to store `continuation` and
40-
/// provide cancellation mechanism.
41-
final class ContinuationWrapper<Wrapped: Continuable> {
42-
/// The underlying continuation referenced.
43-
var value: Wrapped?
44-
45-
/// Creates a new instance with a continuation reference passed.
46-
/// By default no continuation is stored.
47-
///
48-
/// - Parameter value: A continuation reference to store.
49-
///
50-
/// - Returns: The newly created continuation wrapper.
51-
init(value: Wrapped? = nil) {
52-
self.value = value
53-
}
54-
55-
/// Resume continuation with passed error,
56-
/// without checking if continuation already resumed.
57-
///
58-
/// - Parameter error: Error passed to continuation.
59-
func cancel(withError error: Wrapped.Failure) {
60-
value?.resume(throwing: error)
61-
}
62-
}
63-
641
/// A type that allows to interface between synchronous and asynchronous code,
652
/// by representing task state and allowing task resuming with some value or error.
663
protocol Continuable: Sendable {
@@ -151,6 +88,59 @@ extension CheckedContinuation: ThrowingContinuable where E == Error {
15188
}
15289
}
15390

91+
protocol NonThrowingContinuable: Continuable {
92+
/// The type of error to resume the continuation with in case of failure.
93+
associatedtype Failure = Never
94+
/// Suspends the current task, then calls the given closure
95+
/// with a nonthrowing continuation for the current task.
96+
///
97+
/// The continuation can be resumed exactly once,
98+
/// subsequent resumes have different behavior depending on type implemeting.
99+
///
100+
/// - Parameter fn: A closure that takes the nonthrowing continuation parameter.
101+
/// You can resume the continuation exactly once.
102+
///
103+
/// - Returns: The value passed to the continuation by the closure.
104+
@inlinable
105+
static func with(_ fn: (Self) -> Void) async -> Success
106+
}
107+
108+
extension UnsafeContinuation: NonThrowingContinuable where E == Never {
109+
/// Suspends the current task, then calls the given closure
110+
/// with an unsafe nonthrowing continuation for the current task.
111+
///
112+
/// The continuation must be resumed exactly once, subsequent resumes will cause runtime error.
113+
/// Use `CheckedContinuation` to capture relevant data in case of runtime errors.
114+
///
115+
/// - Parameter fn: A closure that takes an `UnsafeContinuation` parameter.
116+
/// You must resume the continuation exactly once.
117+
///
118+
/// - Returns: The value passed to the continuation by the closure.
119+
@inlinable
120+
static func with(_ fn: (UnsafeContinuation<T, E>) -> Void) async -> T {
121+
return await withUnsafeContinuation(fn)
122+
}
123+
}
124+
125+
extension CheckedContinuation: NonThrowingContinuable where E == Never {
126+
/// Suspends the current task, then calls the given closure
127+
/// with a checked nonthrowing continuation for the current task.
128+
///
129+
/// The continuation must be resumed exactly once, subsequent resumes will cause runtime error.
130+
/// `CheckedContinuation` logs messages proving additional info on these errors.
131+
/// Once all errors resolved, use `UnsafeContinuation` in release mode to benefit improved performance
132+
/// at the loss of additional runtime checks.
133+
///
134+
/// - Parameter fn: A closure that takes a `CheckedContinuation` parameter.
135+
/// You must resume the continuation exactly once.
136+
///
137+
/// - Returns: The value passed to the continuation by the closure.
138+
@inlinable
139+
static func with(_ body: (CheckedContinuation<T, E>) -> Void) async -> T {
140+
return await withCheckedContinuation(body)
141+
}
142+
}
143+
154144
#if DEBUG || ASYNCOBJECTS_USE_CHECKEDCONTINUATION
155145
/// The continuation type used in package in `DEBUG` mode
156146
/// or if `ASYNCOBJECTS_USE_CHECKEDCONTINUATION` flag turned on.

0 commit comments

Comments
 (0)