@@ -75,6 +75,7 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
75
75
var buffer = [ Element] ( )
76
76
var finished = false
77
77
var failure : Failure ?
78
+ var cancelled = false
78
79
var limit : CheckedContinuation < Bool , Never > ?
79
80
var demand : CheckedContinuation < Void , Never > ?
80
81
@@ -155,14 +156,17 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
155
156
}
156
157
157
158
func cancel( ) {
158
- // TODO: this currently is a hard cancel, it should be refined to only cancel when everything is terminal
159
159
let ( task, limit, demand, cancelled) = state. withLock { state -> ( IteratingTask ? , CheckedContinuation < Bool , Never > ? , CheckedContinuation < Void , Never > ? , Bool ) in
160
- defer {
161
- state. iteratingTask = . cancelled
162
- state. limit = nil
163
- state. demand = nil
160
+ if state. sides. count == 0 {
161
+ defer {
162
+ state. iteratingTask = . cancelled
163
+ state. cancelled = true
164
+ }
165
+ return state. emit ( state. iteratingTask)
166
+ } else {
167
+ state. cancelled = true
168
+ return state. emit ( nil )
164
169
}
165
- return state. emit ( state. iteratingTask)
166
170
}
167
171
task? . cancel ( )
168
172
limit? . resume ( returning: cancelled)
@@ -178,21 +182,32 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
178
182
}
179
183
180
184
func unregisterSide( _ id: Int ) {
181
- let ( side, continuation, cancelled) = state. withLock { state -> ( Side . State ? , CheckedContinuation < Bool , Never > ? , Bool ) in
185
+ let ( side, continuation, cancelled, iteratingTaskToCancel ) = state. withLock { state -> ( Side . State ? , CheckedContinuation < Bool , Never > ? , Bool , IteratingTask ? ) in
182
186
let side = state. sides. removeValue ( forKey: id)
183
187
state. trimBuffer ( )
188
+ let cancelRequested = state. sides. count == 0 && state. cancelled
184
189
if let limit, state. buffer. count < limit {
185
190
defer { state. limit = nil }
186
191
if case . cancelled = state. iteratingTask {
187
- return ( side, state. limit, true )
192
+ return ( side, state. limit, true , nil )
188
193
} else {
189
- return ( side, state. limit, false )
194
+ defer {
195
+ if cancelRequested {
196
+ state. iteratingTask = . cancelled
197
+ }
198
+ }
199
+ return ( side, state. limit, false , cancelRequested ? state. iteratingTask : nil )
190
200
}
191
201
} else {
192
202
if case . cancelled = state. iteratingTask {
193
- return ( side, nil , true )
203
+ return ( side, nil , true , nil )
194
204
} else {
195
- return ( side, nil , false )
205
+ defer {
206
+ if cancelRequested {
207
+ state. iteratingTask = . cancelled
208
+ }
209
+ }
210
+ return ( side, nil , false , cancelRequested ? state. iteratingTask : nil )
196
211
}
197
212
}
198
213
}
@@ -202,6 +217,9 @@ struct AsyncShareSequence<Base: AsyncSequence>: Sendable where Base.Element: Sen
202
217
if let side {
203
218
side. continuaton? . resume ( returning: . success( nil ) )
204
219
}
220
+ if let iteratingTaskToCancel {
221
+ iteratingTaskToCancel. cancel ( )
222
+ }
205
223
}
206
224
207
225
func iterate( ) async -> Bool {
0 commit comments