@@ -124,7 +124,81 @@ namespace winrt::impl
124124 static constexpr bool has_co_await_member = find_co_await_member<T&&>(0 );
125125 static constexpr bool has_co_await_free = find_co_await_free<T&&>(0 );
126126 };
127+ }
128+
129+ WINRT_EXPORT namespace winrt
130+ {
131+ struct cancellable_promise
132+ {
133+ using canceller_t = void (*)(void *);
134+
135+ void set_canceller (canceller_t canceller, void * context)
136+ {
137+ m_context = context;
138+ m_canceller.store (canceller, std::memory_order_release);
139+ }
140+
141+ void revoke_canceller ()
142+ {
143+ while (m_canceller.exchange (nullptr , std::memory_order_acquire) == cancelling_ptr)
144+ {
145+ std::this_thread::yield ();
146+ }
147+ }
148+
149+ void cancel ()
150+ {
151+ auto canceller = m_canceller.exchange (cancelling_ptr, std::memory_order_acquire);
152+ struct unique_cancellation_lock
153+ {
154+ cancellable_promise* promise;
155+ ~unique_cancellation_lock ()
156+ {
157+ promise->m_canceller .store (nullptr , std::memory_order_release);
158+ }
159+ } lock{ this };
160+
161+ if ((canceller != nullptr ) && (canceller != cancelling_ptr))
162+ {
163+ canceller (m_context);
164+ }
165+ }
166+
167+ private:
168+ static inline auto const cancelling_ptr = reinterpret_cast <canceller_t >(1 );
169+
170+ std::atomic<canceller_t > m_canceller{ nullptr };
171+ void * m_context{ nullptr };
172+ };
173+
174+ struct enable_await_cancellation
175+ {
176+ enable_await_cancellation () noexcept = default ;
177+ enable_await_cancellation (enable_await_cancellation const &) = delete ;
178+
179+ ~enable_await_cancellation ()
180+ {
181+ if (m_promise)
182+ {
183+ m_promise->revoke_canceller ();
184+ }
185+ }
186+
187+ void operator =(enable_await_cancellation const &) = delete ;
188+
189+ void set_cancellable_promise (cancellable_promise* promise) noexcept
190+ {
191+ m_promise = promise;
192+ }
193+
194+ private:
195+
196+ cancellable_promise* m_promise = nullptr ;
197+ };
198+ }
127199
200+ namespace winrt ::impl
201+ {
128202 template <typename T>
129203 decltype (auto ) get_awaiter(T&& value) noexcept
130204 {
@@ -149,8 +223,16 @@ namespace winrt::impl
149223 {
150224 decltype (get_awaiter(std::declval<T&&>())) awaitable;
151225
152- notify_awaiter (T&& awaitable ) : awaitable(get_awaiter(static_cast <T&&>(awaitable )))
226+ notify_awaiter (T&& awaitable_arg, cancellable_promise* promise = nullptr ) : awaitable(get_awaiter(static_cast <T&&>(awaitable_arg )))
153227 {
228+ if constexpr (std::is_convertible_v<std::remove_reference_t <decltype (awaitable)>&, enable_await_cancellation&>)
229+ {
230+ if (promise)
231+ {
232+ static_cast <enable_await_cancellation&>(awaitable).set_cancellable_promise (promise);
233+ awaitable.enable_cancellation (promise);
234+ }
235+ }
154236 }
155237
156238 bool await_ready ()
@@ -271,34 +353,67 @@ WINRT_EXPORT namespace winrt
271353
272354 [[nodiscard]] inline auto resume_after (Windows::Foundation::TimeSpan duration) noexcept
273355 {
274- struct awaitable
356+ struct awaitable : enable_await_cancellation
275357 {
276358 explicit awaitable (Windows::Foundation::TimeSpan duration) noexcept :
277359 m_duration(duration)
278360 {
279361 }
280362
363+ void enable_cancellation (cancellable_promise* promise)
364+ {
365+ promise->set_canceller ([](void * context)
366+ {
367+ auto that = static_cast <awaitable*>(context);
368+ if (that->m_state .exchange (state::canceled, std::memory_order_acquire) == state::pending)
369+ {
370+ that->fire_immediately ();
371+ }
372+ }, this );
373+ }
374+
281375 bool await_ready () const noexcept
282376 {
283377 return m_duration.count () <= 0 ;
284378 }
285379
286380 void await_suspend (std::experimental::coroutine_handle<> handle)
287381 {
288- m_timer.attach (check_pointer (WINRT_IMPL_CreateThreadpoolTimer (callback, handle.address (), nullptr )));
382+ m_handle = handle;
383+ m_timer.attach (check_pointer (WINRT_IMPL_CreateThreadpoolTimer (callback, this , nullptr )));
289384 int64_t relative_count = -m_duration.count ();
290- WINRT_IMPL_SetThreadpoolTimer (m_timer.get (), &relative_count, 0 , 0 );
385+ WINRT_IMPL_SetThreadpoolTimerEx (m_timer.get (), &relative_count, 0 , 0 );
386+
387+ state expected = state::idle;
388+ if (!m_state.compare_exchange_strong (expected, state::pending, std::memory_order_release))
389+ {
390+ fire_immediately ();
391+ }
291392 }
292393
293- void await_resume () const noexcept
394+ void await_resume ()
294395 {
396+ if (m_state.exchange (state::idle, std::memory_order_relaxed) == state::canceled)
397+ {
398+ throw hresult_canceled ();
399+ }
295400 }
296401
297402 private:
298403
404+ void fire_immediately () noexcept
405+ {
406+ if (WINRT_IMPL_SetThreadpoolTimerEx (m_timer.get (), nullptr , 0 , 0 ))
407+ {
408+ int64_t now = 0 ;
409+ WINRT_IMPL_SetThreadpoolTimerEx (m_timer.get (), &now, 0 , 0 );
410+ }
411+ }
412+
299413 static void __stdcall callback (void *, void * context, void *) noexcept
300414 {
301- std::experimental::coroutine_handle<>::from_address (context)();
415+ auto that = reinterpret_cast <awaitable*>(context);
416+ that->m_handle ();
302417 }
303418
304419 struct timer_traits
@@ -316,8 +431,12 @@ WINRT_EXPORT namespace winrt
316431 }
317432 };
318433
434+ enum class state { idle, pending, canceled };
435+
319436 handle_type<timer_traits> m_timer;
320437 Windows::Foundation::TimeSpan m_duration;
438+ std::experimental::coroutine_handle<> m_handle;
439+ std::atomic<state> m_state{ state::idle };
321440 };
322441
323442 return awaitable{ duration };
@@ -332,13 +451,25 @@ WINRT_EXPORT namespace winrt
332451
333452 [[nodiscard]] inline auto resume_on_signal (void * handle, Windows::Foundation::TimeSpan timeout = {}) noexcept
334453 {
335- struct awaitable
454+ struct awaitable : enable_await_cancellation
336455 {
337456 awaitable (void * handle, Windows::Foundation::TimeSpan timeout) noexcept :
338457 m_timeout (timeout),
339458 m_handle (handle)
340459 {}
341460
461+ void enable_cancellation (cancellable_promise* promise)
462+ {
463+ promise->set_canceller ([](void * context)
464+ {
465+ auto that = static_cast <awaitable*>(context);
466+ if (that->m_state .exchange (state::canceled, std::memory_order_acquire) == state::pending)
467+ {
468+ that->fire_immediately ();
469+ }
470+ }, this );
471+ }
472+
342473 bool await_ready () const noexcept
343474 {
344475 return WINRT_IMPL_WaitForSingleObject (m_handle, 0 ) == 0 ;
@@ -350,16 +481,35 @@ WINRT_EXPORT namespace winrt
350481 m_wait.attach (check_pointer (WINRT_IMPL_CreateThreadpoolWait (callback, this , nullptr )));
351482 int64_t relative_count = -m_timeout.count ();
352483 int64_t * file_time = relative_count != 0 ? &relative_count : nullptr ;
353- WINRT_IMPL_SetThreadpoolWait (m_wait.get (), m_handle, file_time);
484+ WINRT_IMPL_SetThreadpoolWaitEx (m_wait.get (), m_handle, file_time, nullptr );
485+
486+ state expected = state::idle;
487+ if (!m_state.compare_exchange_strong (expected, state::pending, std::memory_order_release))
488+ {
489+ fire_immediately ();
490+ }
354491 }
355492
356- bool await_resume () const noexcept
493+ bool await_resume ()
357494 {
495+ if (m_state.exchange (state::idle, std::memory_order_relaxed) == state::canceled)
496+ {
497+ throw hresult_canceled ();
498+ }
358499 return m_result == 0 ;
359500 }
360501
361502 private:
362503
504+ void fire_immediately () noexcept
505+ {
506+ if (WINRT_IMPL_SetThreadpoolWaitEx (m_wait.get (), nullptr , nullptr , nullptr ))
507+ {
508+ int64_t now = 0 ;
509+ WINRT_IMPL_SetThreadpoolWaitEx (m_wait.get (), WINRT_IMPL_GetCurrentProcess (), &now, nullptr );
510+ }
511+ }
512+
363513 static void __stdcall callback (void *, void * context, void *, uint32_t result) noexcept
364514 {
365515 auto that = static_cast <awaitable*>(context);
@@ -382,11 +532,14 @@ WINRT_EXPORT namespace winrt
382532 }
383533 };
384534
535+ enum class state { idle, pending, canceled };
536+
385537 handle_type<wait_traits> m_wait;
386538 Windows::Foundation::TimeSpan m_timeout;
387539 void * m_handle;
388540 uint32_t m_result{};
389541 std::experimental::coroutine_handle<> m_resume{ nullptr };
542+ std::atomic<state> m_state{ state::idle };
390543 };
391544
392545 return awaitable{ handle, timeout };
0 commit comments