Skip to content

Commit c8de4d0

Browse files
committed
Fix the remaining todo on hard cancellation vs soft cancellation
1 parent 8447eef commit c8de4d0

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

Sources/AsyncAlgorithms/AsyncShareSequence.swift

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
7575
var buffer = [Element]()
7676
var finished = false
7777
var failure: Failure?
78+
var cancelled = false
7879
var limit: CheckedContinuation<Bool, Never>?
7980
var demand: CheckedContinuation<Void, Never>?
8081

@@ -155,14 +156,17 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
155156
}
156157

157158
func cancel() {
158-
// TODO: this currently is a hard cancel, it should be refined to only cancel when everything is terminal
159159
let (task, limit, demand, cancelled) = state.withLock { state -> (IteratingTask?, CheckedContinuation<Bool, Never>?, CheckedContinuation<Void, Never>?, Bool) in
160-
defer {
161-
state.iteratingTask = .cancelled
162-
state.limit = nil
163-
state.demand = nil
160+
if state.sides.count == 0 {
161+
defer {
162+
state.iteratingTask = .cancelled
163+
state.cancelled = true
164+
}
165+
return state.emit(state.iteratingTask)
166+
} else {
167+
state.cancelled = true
168+
return state.emit(nil)
164169
}
165-
return state.emit(state.iteratingTask)
166170
}
167171
task?.cancel()
168172
limit?.resume(returning: cancelled)
@@ -178,21 +182,32 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
178182
}
179183

180184
func unregisterSide(_ id: Int) {
181-
let (side, continuation, cancelled) = state.withLock { state -> (Side.State?, CheckedContinuation<Bool, Never>?, Bool) in
185+
let (side, continuation, cancelled, iteratingTaskToCancel) = state.withLock { state -> (Side.State?, CheckedContinuation<Bool, Never>?, Bool, IteratingTask?) in
182186
let side = state.sides.removeValue(forKey: id)
183187
state.trimBuffer()
188+
let cancelRequested = state.sides.count == 0 && state.cancelled
184189
if let limit, state.buffer.count < limit {
185190
defer { state.limit = nil }
186191
if case .cancelled = state.iteratingTask {
187-
return (side, state.limit, true)
192+
return (side, state.limit, true, nil)
188193
} else {
189-
return (side, state.limit, false)
194+
defer {
195+
if cancelRequested {
196+
state.iteratingTask = .cancelled
197+
}
198+
}
199+
return (side, state.limit, false, cancelRequested ? state.iteratingTask : nil)
190200
}
191201
} else {
192202
if case .cancelled = state.iteratingTask {
193-
return (side, nil, true)
203+
return (side, nil, true, nil)
194204
} else {
195-
return (side, nil, false)
205+
defer {
206+
if cancelRequested {
207+
state.iteratingTask = .cancelled
208+
}
209+
}
210+
return (side, nil, false, cancelRequested ? state.iteratingTask : nil)
196211
}
197212
}
198213
}
@@ -202,6 +217,9 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
202217
if let side {
203218
side.continuaton?.resume(returning: .success(nil))
204219
}
220+
if let iteratingTaskToCancel {
221+
iteratingTaskToCancel.cancel()
222+
}
205223
}
206224

207225
func iterate() async -> Bool {

0 commit comments

Comments
 (0)