Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmake/common_build_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ endmacro()

# Fixup default compiler settings
if (MSVC)
# Ensure exception handling is enabled; some CMake versions don't add /EHsc by default for all
# MSVC-compatible compilers (e.g. clang-cl). Adding it to CMAKE_CXX_FLAGS allows targets that
# need to disable exceptions (e.g. noexcept) to remove it via replace_cxx_flag.
if (NOT CMAKE_CXX_FLAGS MATCHES "/EHsc")
string(APPEND CMAKE_CXX_FLAGS " /EHsc")
endif()

add_compile_options(
# Be as strict as reasonably possible, since we want to support consumers using strict warning levels
/W4 /WX
Expand Down
17 changes: 17 additions & 0 deletions include/wil/Tracelogging.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,23 @@ class ActivityBase : public details::IFailureCallback
}
}

bool IsWatching() WI_NOEXCEPT
{
return m_callbackHolder.IsWatching();
}

// Coroutine watcher interface: pause watching during suspension
bool suspend() WI_NOEXCEPT
{
return m_callbackHolder.suspend();
}

// Coroutine watcher interface: resume watching after suspension
void resume() WI_NOEXCEPT
{
m_callbackHolder.resume();
}

// Call this API to retrieve an RAII object to watch events on the current thread. The returned
// object should only be used on the stack.

Expand Down
88 changes: 88 additions & 0 deletions include/wil/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,75 @@ struct task_base

static void __stdcall wake_by_address(void* completed);
};

// Generic awaitable wrapper that suspends/resumes a watcher across co_await.
// TPausable must provide:
// bool suspend() - called before suspension, returns true if resume() should be called
// void resume() - called after resumption if suspend() returned true
template <typename TPausable, typename TChildAwaitable>
struct coroutine_withsuspend_awaiter
{
TPausable& pausable;
TChildAwaitable child_awaitable;
bool resume_needed = false;

bool await_ready() noexcept
{
return child_awaitable.await_ready();
}

template <typename T>
auto await_suspend(T&& handle) noexcept(noexcept(wistd::declval<TChildAwaitable>().await_suspend(wistd::forward<T>(handle))) && noexcept(pausable.suspend()))
{
resume_needed = pausable.suspend();
return child_awaitable.await_suspend(wistd::forward<T>(handle));
}

auto await_resume() noexcept(noexcept(wistd::declval<TChildAwaitable>().await_resume()))
{
if (resume_needed)
{
pausable.resume();
}
return child_awaitable.await_resume();
}
};

// Priority tags for SFINAE-based overload resolution
struct get_awaiter_priority_fallback {};
struct get_awaiter_priority_free_op : get_awaiter_priority_fallback {};
struct get_awaiter_priority_member_op : get_awaiter_priority_free_op {};

// Highest priority: member operator co_await
template <typename T>
auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_member_op)
-> decltype(wistd::forward<T>(awaitable).operator co_await())
{
return wistd::forward<T>(awaitable).operator co_await();
}

// Second priority: free operator co_await
template <typename T>
auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_free_op)
-> decltype(operator co_await(wistd::forward<T>(awaitable)))
{
return operator co_await(wistd::forward<T>(awaitable));
}

// Fallback: return the awaitable itself
template <typename T>
T&& get_awaiter_impl(T&& awaitable, get_awaiter_priority_fallback)
{
return wistd::forward<T>(awaitable);
}

template <typename T>
auto get_awaiter(T&& awaitable)
-> decltype(get_awaiter_impl(wistd::forward<T>(awaitable), get_awaiter_priority_member_op{}))
{
return get_awaiter_impl(wistd::forward<T>(awaitable), get_awaiter_priority_member_op{});
}

} // namespace wil::details::coro
/// @endcond

Expand Down Expand Up @@ -700,6 +769,25 @@ template <typename T>
task(com_task<T>&&) -> task<T>;
template <typename T>
com_task(task<T>&&) -> com_task<T>;

// Wrap an awaitable with a suspend/resume watcher; the watcher will be paused while the awaitable
// is suspended and resumed when the coroutine continues. This prevents capturing errors from other
// threads or unrelated code paths while awaiting.
//
// The watcher type must provide:
// bool suspend() - called before suspension, returns true if resume() should be called
// void resume() - called after resumption if suspend() returned true
//
// Usage:
// ThreadFailureCache cache;
// auto result = co_await wil::with_watcher(cache, SomethingAsync());
template <typename TWatcher, typename TAwaitable>
auto with_watcher(TWatcher& watcher, TAwaitable&& awaitable)
{
using awaiter_t = std::decay_t<decltype(details::coro::get_awaiter(std::forward<TAwaitable>(awaitable)))>;
return details::coro::coroutine_withsuspend_awaiter<TWatcher, awaiter_t>{watcher, std::forward<TAwaitable>(awaitable)};
}

} // namespace wil

template <typename T, typename... Args>
Expand Down
38 changes: 33 additions & 5 deletions include/wil/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,24 @@ namespace details
}
}

bool suspend() WI_NOEXCEPT
{
const bool wasWatching = IsWatching();
if (wasWatching)
{
StopWatching();
}
return wasWatching;
}

void resume() WI_NOEXCEPT
{
if (!IsWatching())
{
StartWatching();
}
}

static bool GetThreadContext(
_Inout_ FailureInfo* pFailure,
_In_opt_ ThreadFailureCallbackHolder* pCallback,
Expand Down Expand Up @@ -1093,14 +1111,14 @@ namespace details
{
public:
explicit ThreadFailureCallbackFn(_In_opt_ CallContextInfo* pContext, _Inout_ TLambda&& errorFunction) WI_NOEXCEPT
: m_errorFunction(wistd::move(errorFunction)),
m_callbackHolder(this, pContext)
: m_callbackHolder(this, pContext),
m_errorFunction(wistd::move(errorFunction))
{
}

ThreadFailureCallbackFn(_Inout_ ThreadFailureCallbackFn&& other) WI_NOEXCEPT
: m_errorFunction(wistd::move(other.m_errorFunction)),
m_callbackHolder(this, other.m_callbackHolder.CallContextInfo())
: m_callbackHolder(this, other.m_callbackHolder.CallContextInfo()),
m_errorFunction(wistd::move(other.m_errorFunction))
{
}

Expand All @@ -1109,12 +1127,22 @@ namespace details
return m_errorFunction(failure);
}

bool suspend() WI_NOEXCEPT
{
return m_callbackHolder.suspend();
}

void resume() WI_NOEXCEPT
{
m_callbackHolder.resume();
}

private:
ThreadFailureCallbackFn(_In_ ThreadFailureCallbackFn const&);
ThreadFailureCallbackFn& operator=(_In_ ThreadFailureCallbackFn const&);

TLambda m_errorFunction;
ThreadFailureCallbackHolder m_callbackHolder;
TLambda m_errorFunction;
};

// returns true if telemetry was reported for this error
Expand Down
64 changes: 64 additions & 0 deletions tests/CoroutineTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ wil::task<void> void_task(std::shared_ptr<int> value)
++*value;
co_return;
}

struct resume_new_cpp_thread
{
bool await_ready() noexcept { return false; }
template<typename Handle>
void await_suspend(Handle handle) noexcept
{
std::thread([handle]
{
handle();
}).detach();
}
void await_resume() {}
};

} // namespace

TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]")
Expand All @@ -42,4 +57,53 @@ TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]")
}).join();
}

TEST_CASE("CoroutineTests::WithWatcherBasic", "[coroutine]")
{
// Test that wil::with_watcher wraps an awaitable and calls suspend()/resume()
// on the watcher object across co_await.
struct mock_watcher
{
int suspend_count = 0;
int resume_count = 0;
bool suspend() noexcept { ++suspend_count; return true; }
void resume() noexcept { ++resume_count; }
};

auto test = [](mock_watcher& watcher) -> wil::task<void>
{
co_await wil::with_watcher(watcher, resume_new_cpp_thread{});
};

std::thread([&] {
mock_watcher watcher;
std::move(test(watcher)).get();
REQUIRE(watcher.suspend_count == 1);
REQUIRE(watcher.resume_count == 1);
}).join();
}

TEST_CASE("CoroutineTests::WithWatcherSuspendReturnsFalse", "[coroutine]")
{
// When suspend() returns false, resume() should not be called.
struct mock_watcher
{
int suspend_count = 0;
int resume_count = 0;
bool suspend() noexcept { ++suspend_count; return false; }
void resume() noexcept { ++resume_count; }
};

auto test = [](mock_watcher& watcher) -> wil::task<void>
{
co_await wil::with_watcher(watcher, resume_new_cpp_thread{});
};

std::thread([&] {
mock_watcher watcher;
std::move(test(watcher)).get();
REQUIRE(watcher.suspend_count == 1);
REQUIRE(watcher.resume_count == 0);
}).join();
}

#endif // coroutines
63 changes: 63 additions & 0 deletions tests/CppWinRTTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,69 @@ TEST_CASE("CppWinRTTests::ResumeForegroundTests", "[cppwinrt]")
}()
.get();
}

namespace
{
struct resume_new_cpp_thread_for_watcher
{
bool await_ready() noexcept { return false; }
template<typename Handle>
void await_suspend(Handle handle) noexcept
{
std::thread([handle]
{
handle();
}).detach();
}
void await_resume() {}
};
} // namespace

TEST_CASE("CppWinRTTests::WithWatcherThreadFailureCallback", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher correctly pauses/resumes a ThreadFailureCallback across co_await.
auto test = []() -> wil::task<void>
{
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { return false; });
co_await wil::with_watcher(watcher, resume_new_cpp_thread_for_watcher{});
};

std::move(test()).get();
}

TEST_CASE("CppWinRTTests::WithWatcherWinRTAction", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher works with a WinRT IAsyncAction.
auto test = []() -> winrt::Windows::Foundation::IAsyncAction
{
auto tid = ::GetCurrentThreadId();
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { return false; });
co_await wil::with_watcher(watcher, winrt::resume_background());
REQUIRE(tid != ::GetCurrentThreadId());
};

test().get();
}

TEST_CASE("CppWinRTTests::WithWatcherWinRTOperation", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher works with a WinRT IAsyncOperation.
auto inner = []() -> winrt::Windows::Foundation::IAsyncOperation<winrt::hstring>
{
co_await winrt::resume_background();
co_return winrt::hstring(L"kittens");
};

auto test = [&inner]() -> winrt::Windows::Foundation::IAsyncAction
{
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { return false; });
auto result = co_await wil::with_watcher(watcher, inner());
REQUIRE(result == L"kittens");
};

test().get();
}

#endif // coroutines

TEST_CASE("CppWinRTTests::ThrownExceptionWithMessage", "[cppwinrt]")
Expand Down
33 changes: 33 additions & 0 deletions tests/TraceLoggingTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,36 @@
// Just verify that Tracelogging.h compiles.
#define PROVIDER_CLASS_NAME TestProvider
#include "TraceLoggingTests.h"

#include "catch.hpp"
#include "common.h"

TEST_CASE("TraceLoggingTests::ActivitySuspendResume", "[tracelogging]")
{
// Test that Activity classes implement the suspend/resume interface for coroutine watchers.
// This interface is used by wil::with_watcher() to pause error watching during co_await.
auto activity = TestProvider::TraceloggingActivity::Start();

// Initially watching after Start()
REQUIRE(activity.IsRunning());

// suspend() should return true (was watching) and stop watching
bool wasWatching = activity.suspend();
REQUIRE(wasWatching);
REQUIRE_FALSE(activity.IsWatching());

// Calling suspend() again should return false (wasn't watching)
wasWatching = activity.suspend();
REQUIRE_FALSE(wasWatching);
REQUIRE_FALSE(activity.IsWatching());

// resume() should restart watching
activity.resume();
REQUIRE(activity.IsWatching());

// Calling resume() when already watching is safe
activity.resume();
REQUIRE(activity.IsWatching());

activity.Stop();
}
Loading