Skip to content

Commit 24650de

Browse files
authored
Fix weak references to coroutines (#1097)
1 parent 1ccfe2e commit 24650de

File tree

2 files changed

+92
-6
lines changed

2 files changed

+92
-6
lines changed

strings/base_implements.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ namespace winrt::impl
831831
virtual ~root_implements() noexcept
832832
{
833833
// If a weak reference is created during destruction, this ensures that it is also destroyed.
834-
subtract_reference();
834+
subtract_final_reference();
835835
}
836836

837837
int32_t __stdcall GetIids(uint32_t* count, guid** array) noexcept
@@ -897,10 +897,6 @@ namespace winrt::impl
897897

898898
if (target == 0)
899899
{
900-
// If a weak reference was previously created, the m_references value will not be stable value (won't be zero).
901-
// This ensures destruction has a stable value during destruction.
902-
m_references = 1;
903-
904900
if constexpr (has_final_release::value)
905901
{
906902
D::final_release(std::unique_ptr<D>(static_cast<D*>(this)));
@@ -992,7 +988,7 @@ namespace winrt::impl
992988
}
993989
catch (...) { return to_hresult(); }
994990

995-
uint32_t subtract_reference() noexcept
991+
uint32_t subtract_final_reference() noexcept
996992
{
997993
if constexpr (is_weak_ref_source::value)
998994
{
@@ -1019,6 +1015,19 @@ namespace winrt::impl
10191015
}
10201016
}
10211017

1018+
uint32_t subtract_reference() noexcept
1019+
{
1020+
uint32_t result = subtract_final_reference();
1021+
1022+
if (result == 0)
1023+
{
1024+
// Ensure destruction happens with a stable reference count that isn't a weak reference.
1025+
m_references.store(1, std::memory_order_relaxed);
1026+
}
1027+
1028+
return result;
1029+
}
1030+
10221031
template <typename T>
10231032
winrt::weak_ref<T> get_weak()
10241033
{

test/old_tests/UnitTests/weak.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,49 @@ namespace
6868
REQUIRE(weak_self.get() == nullptr);
6969
}
7070
};
71+
72+
struct WeakCreateWeakInDestructor : implements<WeakCreateWeakInDestructor, IStringable>
73+
{
74+
winrt::weak_ref<WeakCreateWeakInDestructor>& weak_self;
75+
76+
WeakCreateWeakInDestructor(winrt::weak_ref<WeakCreateWeakInDestructor>& magic) : weak_self(magic) {}
77+
78+
~WeakCreateWeakInDestructor()
79+
{
80+
// Creates a weak reference to itself in the destructor.
81+
weak_self = get_weak();
82+
}
83+
84+
hstring ToString()
85+
{
86+
return L"WeakCreateWeakInDestructor";
87+
}
88+
};
89+
90+
#ifdef WINRT_IMPL_COROUTINES
91+
// Returns an IAsyncAction that has already completed.
92+
winrt::Windows::Foundation::IAsyncAction Action()
93+
{
94+
co_return;
95+
}
96+
97+
// Returns an IAsyncAction that has not completed.
98+
// Call the resume() handle to complete it.
99+
winrt::Windows::Foundation::IAsyncAction SuspendAction(impl::coroutine_handle<>& resume)
100+
{
101+
struct awaiter
102+
{
103+
impl::coroutine_handle<>& resume;
104+
bool await_ready() { return false; }
105+
void await_suspend(impl::coroutine_handle<> handle) { resume = handle; }
106+
void await_resume() {}
107+
};
108+
109+
co_await awaiter{ resume };
110+
co_return;
111+
}
112+
113+
#endif
71114
}
72115

73116
TEST_CASE("weak,source")
@@ -413,3 +456,37 @@ TEST_CASE("weak,self")
413456
a.ToString();
414457
a = nullptr;
415458
}
459+
460+
TEST_CASE("weak,create_weak_in_destructor")
461+
{
462+
weak_ref<WeakCreateWeakInDestructor> magic;
463+
IStringable a = make<WeakCreateWeakInDestructor>(magic);
464+
a.ToString();
465+
a = nullptr;
466+
REQUIRE(magic.get() == nullptr);
467+
}
468+
469+
#ifdef WINRT_IMPL_COROUTINES
470+
TEST_CASE("weak,coroutine")
471+
{
472+
// Run a coroutine to completion. Confirm that weak references fail to resolve.
473+
auto weak = winrt::weak_ref(Action());
474+
REQUIRE(weak.get() == nullptr);
475+
476+
// Start a coroutine but don't complete it yet.
477+
// Confirm that weak references resolve.
478+
impl::coroutine_handle<> resume;
479+
weak = winrt::weak_ref(SuspendAction(resume));
480+
REQUIRE(weak.get() != nullptr);
481+
// Now complete the coroutine. Confirm that weak references no longer resolve.
482+
resume();
483+
REQUIRE(weak.get() == nullptr);
484+
485+
// Verify that weak reference resolves as long as strong reference exists.
486+
auto action = Action();
487+
weak = winrt::weak_ref(action);
488+
REQUIRE(weak.get() == action);
489+
action = nullptr;
490+
REQUIRE(weak.get() == nullptr);
491+
}
492+
#endif

0 commit comments

Comments
 (0)