Skip to content

Commit 82fa8ec

Browse files
authored
Improve async cancellation interop (#643)
1 parent 78d212a commit 82fa8ec

File tree

2 files changed

+143
-26
lines changed

2 files changed

+143
-26
lines changed

strings/base_coroutine_foundation.h

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,32 @@ namespace winrt::impl
3838
WINRT_ASSERT(!is_sta());
3939
}
4040

41-
template <typename Async>
42-
void wait_for_completed(Async const& async, uint32_t const timeout)
41+
template <typename T, typename H>
42+
std::pair<T, H*> make_delegate_with_shared_state(H&& handler)
4343
{
44-
void* event = check_pointer(WINRT_IMPL_CreateEventW(nullptr, true, false, nullptr));
44+
auto d = make_delegate<T, H>(std::forward<H>(handler));
45+
return { std::move(d), reinterpret_cast<delegate<T, H>*>(get_abi(d)) };
46+
}
4547

46-
// The delegate is a local to ensure that the event outlives the call to WaitForSingleObject.
47-
async_completed_handler_t<Async> delegate = [event = handle(event)](auto && ...)
48+
template <typename Async>
49+
auto wait_for_completed(Async const& async, uint32_t const timeout)
50+
{
51+
struct shared_type
4852
{
49-
WINRT_VERIFY(WINRT_IMPL_SetEvent(event.get()));
53+
handle event{ check_pointer(WINRT_IMPL_CreateEventW(nullptr, true, false, nullptr)) };
54+
Windows::Foundation::AsyncStatus status{ Windows::Foundation::AsyncStatus::Started };
55+
56+
void operator()(Async const&, Windows::Foundation::AsyncStatus operation_status) noexcept
57+
{
58+
status = operation_status;
59+
WINRT_VERIFY(WINRT_IMPL_SetEvent(event.get()));
60+
}
5061
};
5162

63+
auto [delegate, shared] = make_delegate_with_shared_state<async_completed_handler_t<Async>>(shared_type{});
5264
async.Completed(delegate);
53-
WINRT_IMPL_WaitForSingleObject(event, timeout);
65+
WINRT_IMPL_WaitForSingleObject(shared->event.get(), timeout);
66+
return shared->status;
5467
}
5568

5669
template <typename Async>
@@ -59,19 +72,28 @@ namespace winrt::impl
5972
check_sta_blocking_wait();
6073
auto const milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(timeout).count();
6174
WINRT_ASSERT((milliseconds >= 0) && (static_cast<uint64_t>(milliseconds) < 0xFFFFFFFFull)); // Within uint32_t range and not INFINITE
62-
wait_for_completed(async, static_cast<uint32_t>(milliseconds));
63-
return async.Status();
75+
return wait_for_completed(async, static_cast<uint32_t>(milliseconds));
76+
}
77+
78+
inline void check_status_canceled(Windows::Foundation::AsyncStatus status)
79+
{
80+
if (status == Windows::Foundation::AsyncStatus::Canceled)
81+
{
82+
throw hresult_canceled();
83+
}
6484
}
6585

6686
template <typename Async>
6787
auto wait_get(Async const& async)
6888
{
6989
check_sta_blocking_wait();
7090

71-
if (async.Status() == Windows::Foundation::AsyncStatus::Started)
91+
auto status = async.Status();
92+
if (status == Windows::Foundation::AsyncStatus::Started)
7293
{
73-
wait_for_completed(async, 0xFFFFFFFF); // INFINITE
94+
status = wait_for_completed(async, 0xFFFFFFFF); // INFINITE
7495
}
96+
check_status_canceled(status);
7597

7698
return async.GetResults();
7799
}
@@ -90,7 +112,7 @@ namespace winrt::impl
90112
if (m_handle) Complete();
91113
}
92114

93-
void operator()(Windows::Foundation::IAsyncInfo const&, Windows::Foundation::AsyncStatus)
115+
void operator()()
94116
{
95117
Complete();
96118
}
@@ -109,19 +131,25 @@ namespace winrt::impl
109131
struct await_adapter
110132
{
111133
Async const& async;
134+
Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started;
112135

113136
bool await_ready() const noexcept
114137
{
115138
return false;
116139
}
117140

118-
void await_suspend(std::experimental::coroutine_handle<> handle) const
141+
void await_suspend(std::experimental::coroutine_handle<> handle)
119142
{
120-
async.Completed(disconnect_aware_handler{ handle });
143+
async.Completed([this, handler = disconnect_aware_handler{ handle }](auto&&, auto operation_status) mutable
144+
{
145+
status = operation_status;
146+
handler();
147+
});
121148
}
122149

123150
auto await_resume() const
124151
{
152+
check_status_canceled(status);
125153
return async.GetResults();
126154
}
127155
};
@@ -691,28 +719,33 @@ WINRT_EXPORT namespace winrt
691719
struct shared_type
692720
{
693721
handle event{ check_pointer(WINRT_IMPL_CreateEventW(nullptr, true, false, nullptr)) };
722+
Windows::Foundation::AsyncStatus status{ Windows::Foundation::AsyncStatus::Started };
694723
T result;
724+
725+
void operator()(T const& sender, Windows::Foundation::AsyncStatus operation_status) noexcept
726+
{
727+
auto sender_abi = *(impl::unknown_abi**)&sender;
728+
729+
if (nullptr == _InterlockedCompareExchangePointer(reinterpret_cast<void**>(&result), sender_abi, nullptr))
730+
{
731+
sender_abi->AddRef();
732+
status = operation_status;
733+
WINRT_VERIFY(WINRT_IMPL_SetEvent(event.get()));
734+
}
735+
}
695736
};
696737

697-
auto shared = std::make_shared<shared_type>();
738+
auto [delegate, shared] = impl::make_delegate_with_shared_state<impl::async_completed_handler_t<T>>(shared_type{});
698739

699740
auto completed = [&](T const& async)
700741
{
701-
async.Completed([shared](T const& sender, Windows::Foundation::AsyncStatus) noexcept
702-
{
703-
auto sender_abi = *(impl::unknown_abi**)&sender;
704-
705-
if (nullptr == _InterlockedCompareExchangePointer(reinterpret_cast<void**>(&shared->result), sender_abi, nullptr))
706-
{
707-
sender_abi->AddRef();
708-
WINRT_VERIFY(WINRT_IMPL_SetEvent(shared->event.get()));
709-
}
710-
});
742+
async.Completed(delegate);
711743
};
712744

713745
completed(first);
714746
(completed(rest), ...);
715747
co_await resume_on_signal(shared->event.get());
748+
impl::check_status_canceled(shared->status);
716749
co_return shared->result.GetResults();
717750
}
718751
}

test/old_tests/UnitTests/async_cancel.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,60 @@ namespace
5858
REQUIRE(!signaled(callback));
5959
co_await resume_on_signal(end.get());
6060
}
61+
62+
// Other projections report cancellation via the Completed handler and Status,
63+
// rather than via ErrorCode and GetResults. Verify we interop cancellation properly.
64+
template<typename T, typename Async>
65+
struct foreign_canceled_async : implements<T, Async, IAsyncInfo>
66+
{
67+
template<typename Handler>
68+
void Completed(Handler&& complete)
69+
{
70+
complete(*this, AsyncStatus::Canceled);
71+
}
72+
73+
auto Completed() const noexcept
74+
{
75+
return nullptr;
76+
}
77+
78+
uint32_t Id() const noexcept
79+
{
80+
return 1;
81+
}
82+
83+
AsyncStatus Status() const noexcept
84+
{
85+
return AsyncStatus::Canceled;
86+
}
87+
88+
hresult ErrorCode() const noexcept
89+
{
90+
return impl::error_illegal_method_call;
91+
}
92+
93+
decltype(std::declval<Async>().GetResults()) GetResults() const
94+
{
95+
throw_hresult(ErrorCode());
96+
}
97+
98+
void Cancel() const noexcept
99+
{
100+
}
101+
102+
void Close() const noexcept
103+
{
104+
}
105+
};
106+
107+
struct foreign_canceled_action : foreign_canceled_async<foreign_canceled_action, IAsyncAction>
108+
{
109+
};
110+
111+
template<typename TResult>
112+
struct foreign_canceled_operation : foreign_canceled_async<foreign_canceled_operation<TResult>, IAsyncOperation<TResult>>
113+
{
114+
};
61115
}
62116

63117
TEST_CASE("async_cancel_no_async")
@@ -98,4 +152,34 @@ TEST_CASE("async_cancel_after_callback")
98152
SetEvent(end.get());
99153
wait(callback);
100154
REQUIRE(async.Status() == AsyncStatus::Canceled);
101-
}
155+
}
156+
157+
TEST_CASE("async_cancel_use_status")
158+
{
159+
// Validate that co_await preserves cancellation.
160+
handle complete{ CreateEvent(nullptr, true, false, nullptr) };
161+
[](void* complete) -> fire_and_forget
162+
{
163+
REQUIRE_THROWS_AS(co_await make<foreign_canceled_action>(), hresult_canceled);
164+
REQUIRE_THROWS_AS(co_await make<foreign_canceled_operation<int32_t>>(), hresult_canceled);
165+
166+
REQUIRE_THROWS_AS(co_await when_any(make<foreign_canceled_action>(), make<foreign_canceled_action>()), hresult_canceled);
167+
REQUIRE_THROWS_AS(co_await when_any(make<foreign_canceled_operation<int>>(), make<foreign_canceled_operation<int>>()), hresult_canceled);
168+
169+
REQUIRE_THROWS_AS(co_await when_all(make<foreign_canceled_action>(), make<foreign_canceled_action>()), hresult_canceled);
170+
REQUIRE_THROWS_AS(co_await when_all(make<foreign_canceled_operation<int>>(), make<foreign_canceled_operation<int>>()), hresult_canceled);
171+
172+
SetEvent(complete);
173+
}(complete.get());
174+
WaitForSingleObject(complete.get(), INFINITE);
175+
176+
// Validate that get() preserves cancellation.
177+
REQUIRE_THROWS_AS(make<foreign_canceled_action>().get(), hresult_canceled);
178+
REQUIRE_THROWS_AS(make<foreign_canceled_operation<int>>().get(), hresult_canceled);
179+
180+
REQUIRE_THROWS_AS(when_any(make<foreign_canceled_action>(), make<foreign_canceled_action>()).get(), hresult_canceled);
181+
REQUIRE_THROWS_AS(when_any(make<foreign_canceled_operation<int>>(), make<foreign_canceled_operation<int>>()).get(), hresult_canceled);
182+
183+
REQUIRE_THROWS_AS(when_all(make<foreign_canceled_action>(), make<foreign_canceled_action>()).get(), hresult_canceled);
184+
REQUIRE_THROWS_AS(when_all(make<foreign_canceled_operation<int>>(), make<foreign_canceled_operation<int>>()).get(), hresult_canceled);
185+
}

0 commit comments

Comments
 (0)