Skip to content

Commit c20a75b

Browse files
authored
Allow cancellation to be propagated to child coroutines (#721)
1 parent e504a0e commit c20a75b

File tree

6 files changed

+346
-15
lines changed

6 files changed

+346
-15
lines changed

strings/base_coroutine_foundation.h

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,21 @@ namespace winrt::impl
128128
};
129129

130130
template <typename Async>
131-
struct await_adapter
131+
struct await_adapter : enable_await_cancellation
132132
{
133+
await_adapter(Async const& async) : async(async) { }
134+
133135
Async const& async;
134136
Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started;
135137

138+
void enable_cancellation(cancellable_promise* promise)
139+
{
140+
promise->set_canceller([](void* parameter)
141+
{
142+
cancel_asynchronously(reinterpret_cast<await_adapter*>(parameter)->async);
143+
}, this);
144+
}
145+
136146
bool await_ready() const noexcept
137147
{
138148
return false;
@@ -153,6 +163,19 @@ namespace winrt::impl
153163
check_status_canceled(status);
154164
return async.GetResults();
155165
}
166+
167+
private:
168+
static fire_and_forget cancel_asynchronously(Async async)
169+
{
170+
co_await winrt::resume_background();
171+
try
172+
{
173+
async.Cancel();
174+
}
175+
catch (hresult_error const&)
176+
{
177+
}
178+
}
156179
};
157180

158181
template <typename D>
@@ -278,6 +301,11 @@ namespace winrt::impl
278301
m_promise->cancellation_callback(std::move(cancel));
279302
}
280303

304+
bool enable_propagation(bool value = true) const noexcept
305+
{
306+
return m_promise->enable_cancellation_propagation(value);
307+
}
308+
281309
private:
282310

283311
Promise* m_promise;
@@ -414,6 +442,8 @@ namespace winrt::impl
414442
{
415443
cancel();
416444
}
445+
446+
m_cancellable.cancel();
417447
}
418448

419449
void Close() const noexcept
@@ -536,7 +566,7 @@ namespace winrt::impl
536566
throw winrt::hresult_canceled();
537567
}
538568

539-
return notify_awaiter<Expression>{ static_cast<Expression&&>(expression) };
569+
return notify_awaiter<Expression>{ static_cast<Expression&&>(expression), m_propagate_cancellation ? &m_cancellable : nullptr };
540570
}
541571

542572
cancellation_token<Derived> await_transform(get_cancellation_token_t) noexcept
@@ -567,6 +597,11 @@ namespace winrt::impl
567597
}
568598
}
569599

600+
bool enable_cancellation_propagation(bool value) noexcept
601+
{
602+
return std::exchange(m_propagate_cancellation, value);
603+
}
604+
570605
#if defined(_DEBUG) && !defined(WINRT_NO_MAKE_DETECTION)
571606
void use_make_function_to_create_this_object() final
572607
{
@@ -587,8 +622,10 @@ namespace winrt::impl
587622
slim_mutex m_lock;
588623
async_completed_handler_t<AsyncInterface> m_completed;
589624
winrt::delegate<> m_cancel;
625+
cancellable_promise m_cancellable;
590626
std::atomic<AsyncStatus> m_status;
591627
bool m_completed_assigned{ false };
628+
bool m_propagate_cancellation{ false };
592629
};
593630
}
594631

strings/base_coroutine_threadpool.h

Lines changed: 162 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,81 @@ namespace winrt::impl
124124
static constexpr bool has_co_await_member = find_co_await_member<T&&>(0);
125125
static constexpr bool has_co_await_free = find_co_await_free<T&&>(0);
126126
};
127+
}
128+
129+
WINRT_EXPORT namespace winrt
130+
{
131+
struct cancellable_promise
132+
{
133+
using canceller_t = void(*)(void*);
134+
135+
void set_canceller(canceller_t canceller, void* context)
136+
{
137+
m_context = context;
138+
m_canceller.store(canceller, std::memory_order_release);
139+
}
140+
141+
void revoke_canceller()
142+
{
143+
while (m_canceller.exchange(nullptr, std::memory_order_acquire) == cancelling_ptr)
144+
{
145+
std::this_thread::yield();
146+
}
147+
}
148+
149+
void cancel()
150+
{
151+
auto canceller = m_canceller.exchange(cancelling_ptr, std::memory_order_acquire);
152+
struct unique_cancellation_lock
153+
{
154+
cancellable_promise* promise;
155+
~unique_cancellation_lock()
156+
{
157+
promise->m_canceller.store(nullptr, std::memory_order_release);
158+
}
159+
} lock{ this };
160+
161+
if ((canceller != nullptr) && (canceller != cancelling_ptr))
162+
{
163+
canceller(m_context);
164+
}
165+
}
166+
167+
private:
168+
static inline auto const cancelling_ptr = reinterpret_cast<canceller_t>(1);
169+
170+
std::atomic<canceller_t> m_canceller{ nullptr };
171+
void* m_context{ nullptr };
172+
};
173+
174+
struct enable_await_cancellation
175+
{
176+
enable_await_cancellation() noexcept = default;
177+
enable_await_cancellation(enable_await_cancellation const&) = delete;
178+
179+
~enable_await_cancellation()
180+
{
181+
if (m_promise)
182+
{
183+
m_promise->revoke_canceller();
184+
}
185+
}
186+
187+
void operator=(enable_await_cancellation const&) = delete;
188+
189+
void set_cancellable_promise(cancellable_promise* promise) noexcept
190+
{
191+
m_promise = promise;
192+
}
193+
194+
private:
195+
196+
cancellable_promise* m_promise = nullptr;
197+
};
198+
}
127199

200+
namespace winrt::impl
201+
{
128202
template <typename T>
129203
decltype(auto) get_awaiter(T&& value) noexcept
130204
{
@@ -149,8 +223,16 @@ namespace winrt::impl
149223
{
150224
decltype(get_awaiter(std::declval<T&&>())) awaitable;
151225

152-
notify_awaiter(T&& awaitable) : awaitable(get_awaiter(static_cast<T&&>(awaitable)))
226+
notify_awaiter(T&& awaitable_arg, cancellable_promise* promise = nullptr) : awaitable(get_awaiter(static_cast<T&&>(awaitable_arg)))
153227
{
228+
if constexpr (std::is_convertible_v<std::remove_reference_t<decltype(awaitable)>&, enable_await_cancellation&>)
229+
{
230+
if (promise)
231+
{
232+
static_cast<enable_await_cancellation&>(awaitable).set_cancellable_promise(promise);
233+
awaitable.enable_cancellation(promise);
234+
}
235+
}
154236
}
155237

156238
bool await_ready()
@@ -271,34 +353,67 @@ WINRT_EXPORT namespace winrt
271353

272354
[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
273355
{
274-
struct awaitable
356+
struct awaitable : enable_await_cancellation
275357
{
276358
explicit awaitable(Windows::Foundation::TimeSpan duration) noexcept :
277359
m_duration(duration)
278360
{
279361
}
280362

363+
void enable_cancellation(cancellable_promise* promise)
364+
{
365+
promise->set_canceller([](void* context)
366+
{
367+
auto that = static_cast<awaitable*>(context);
368+
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
369+
{
370+
that->fire_immediately();
371+
}
372+
}, this);
373+
}
374+
281375
bool await_ready() const noexcept
282376
{
283377
return m_duration.count() <= 0;
284378
}
285379

286380
void await_suspend(std::experimental::coroutine_handle<> handle)
287381
{
288-
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, handle.address(), nullptr)));
382+
m_handle = handle;
383+
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr)));
289384
int64_t relative_count = -m_duration.count();
290-
WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0);
385+
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0);
386+
387+
state expected = state::idle;
388+
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
389+
{
390+
fire_immediately();
391+
}
291392
}
292393

293-
void await_resume() const noexcept
394+
void await_resume()
294395
{
396+
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
397+
{
398+
throw hresult_canceled();
399+
}
295400
}
296401

297402
private:
298403

404+
void fire_immediately() noexcept
405+
{
406+
if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0))
407+
{
408+
int64_t now = 0;
409+
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0);
410+
}
411+
}
412+
299413
static void __stdcall callback(void*, void* context, void*) noexcept
300414
{
301-
std::experimental::coroutine_handle<>::from_address(context)();
415+
auto that = reinterpret_cast<awaitable*>(context);
416+
that->m_handle();
302417
}
303418

304419
struct timer_traits
@@ -316,8 +431,12 @@ WINRT_EXPORT namespace winrt
316431
}
317432
};
318433

434+
enum class state { idle, pending, canceled };
435+
319436
handle_type<timer_traits> m_timer;
320437
Windows::Foundation::TimeSpan m_duration;
438+
std::experimental::coroutine_handle<> m_handle;
439+
std::atomic<state> m_state{ state::idle };
321440
};
322441

323442
return awaitable{ duration };
@@ -332,13 +451,25 @@ WINRT_EXPORT namespace winrt
332451

333452
[[nodiscard]] inline auto resume_on_signal(void* handle, Windows::Foundation::TimeSpan timeout = {}) noexcept
334453
{
335-
struct awaitable
454+
struct awaitable : enable_await_cancellation
336455
{
337456
awaitable(void* handle, Windows::Foundation::TimeSpan timeout) noexcept :
338457
m_timeout(timeout),
339458
m_handle(handle)
340459
{}
341460

461+
void enable_cancellation(cancellable_promise* promise)
462+
{
463+
promise->set_canceller([](void* context)
464+
{
465+
auto that = static_cast<awaitable*>(context);
466+
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
467+
{
468+
that->fire_immediately();
469+
}
470+
}, this);
471+
}
472+
342473
bool await_ready() const noexcept
343474
{
344475
return WINRT_IMPL_WaitForSingleObject(m_handle, 0) == 0;
@@ -350,16 +481,35 @@ WINRT_EXPORT namespace winrt
350481
m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr)));
351482
int64_t relative_count = -m_timeout.count();
352483
int64_t* file_time = relative_count != 0 ? &relative_count : nullptr;
353-
WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time);
484+
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), m_handle, file_time, nullptr);
485+
486+
state expected = state::idle;
487+
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
488+
{
489+
fire_immediately();
490+
}
354491
}
355492

356-
bool await_resume() const noexcept
493+
bool await_resume()
357494
{
495+
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
496+
{
497+
throw hresult_canceled();
498+
}
358499
return m_result == 0;
359500
}
360501

361502
private:
362503

504+
void fire_immediately() noexcept
505+
{
506+
if (WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), nullptr, nullptr, nullptr))
507+
{
508+
int64_t now = 0;
509+
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), WINRT_IMPL_GetCurrentProcess(), &now, nullptr);
510+
}
511+
}
512+
363513
static void __stdcall callback(void*, void* context, void*, uint32_t result) noexcept
364514
{
365515
auto that = static_cast<awaitable*>(context);
@@ -382,11 +532,14 @@ WINRT_EXPORT namespace winrt
382532
}
383533
};
384534

535+
enum class state { idle, pending, canceled };
536+
385537
handle_type<wait_traits> m_wait;
386538
Windows::Foundation::TimeSpan m_timeout;
387539
void* m_handle;
388540
uint32_t m_result{};
389541
std::experimental::coroutine_handle<> m_resume{ nullptr };
542+
std::atomic<state> m_state{ state::idle };
390543
};
391544

392545
return awaitable{ handle, timeout };

strings/base_extern.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ extern "C"
6363

6464
int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept;
6565
winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept;
66-
void __stdcall WINRT_IMPL_SetThreadpoolTimer(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
66+
int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
6767
void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept;
6868
winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept;
69-
void __stdcall WINRT_IMPL_SetThreadpoolWait(winrt::impl::ptp_wait wait, void* handle, void* timeout) noexcept;
69+
int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait wait, void* handle, void* timeout, void* reserved) noexcept;
7070
void __stdcall WINRT_IMPL_CloseThreadpoolWait(winrt::impl::ptp_wait wait) noexcept;
7171
winrt::impl::ptp_io __stdcall WINRT_IMPL_CreateThreadpoolIo(void* object, void(__stdcall *callback)(void*, void* context, void* overlapped, uint32_t result, std::size_t bytes, void*) noexcept, void* context, void*) noexcept;
7272
void __stdcall WINRT_IMPL_StartThreadpoolIo(winrt::impl::ptp_io io) noexcept;
@@ -147,10 +147,10 @@ WINRT_IMPL_LINK(WaitForSingleObject, 8)
147147

148148
WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12)
149149
WINRT_IMPL_LINK(CreateThreadpoolTimer, 12)
150-
WINRT_IMPL_LINK(SetThreadpoolTimer, 16)
150+
WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16)
151151
WINRT_IMPL_LINK(CloseThreadpoolTimer, 4)
152152
WINRT_IMPL_LINK(CreateThreadpoolWait, 12)
153-
WINRT_IMPL_LINK(SetThreadpoolWait, 12)
153+
WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16)
154154
WINRT_IMPL_LINK(CloseThreadpoolWait, 4)
155155
WINRT_IMPL_LINK(CreateThreadpoolIo, 16)
156156
WINRT_IMPL_LINK(StartThreadpoolIo, 4)

strings/base_includes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stdexcept>
1313
#include <string_view>
1414
#include <string>
15+
#include <thread>
1516
#include <tuple>
1617
#include <type_traits>
1718
#include <unordered_map>

0 commit comments

Comments
 (0)